Example #1
def build_tracking_model(opt, device='/cpu:0'):
    """
    Given the T+1 sequence of input, return T sequence of output.
    """
    model = {}

    rnn_seq_len = opt['rnn_seq_len']
    cnn_filter_size = opt['cnn_filter_size']
    cnn_num_filter = opt['cnn_num_filter']
    cnn_pool_size = opt['cnn_pool_size']
    num_channel = opt['img_channel']
    use_bn = opt['use_batch_norm']
    height = opt['img_height']
    width = opt['img_width']
    weight_decay = opt['weight_decay']
    rnn_hidden_dim = opt['rnn_hidden_dim']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay_step = opt['learn_rate_decay_step']
    learn_rate_decay_rate = opt['learn_rate_decay_rate']
    pretrain_model_filename = opt['pretrain_model_filename']
    is_pretrain = opt['is_pretrain']

    with tf.device(get_device_fn(device)):
        phase_train = tf.placeholder('bool')

        # input image [B, T+1, H, W, C]
        anneal_threshold = tf.placeholder(tf.float32, [1])
        imgs = tf.placeholder(
            tf.float32, [None, rnn_seq_len + 1, height, width, num_channel])
        img_shape = tf.shape(imgs)
        batch_size = img_shape[0]

        init_bbox = tf.placeholder(tf.float32, [None, 4])
        init_rnn_state = tf.placeholder(tf.float32, [None, rnn_hidden_dim * 2])
        gt_bbox = tf.placeholder(tf.float32, [None, rnn_seq_len + 1, 4])
        gt_score = tf.placeholder(tf.float32, [None, rnn_seq_len + 1])
        IOU_score = [None] * (rnn_seq_len + 1)
        IOU_score[0] = 1

        model['imgs'] = imgs
        model['gt_bbox'] = gt_bbox
        model['gt_score'] = gt_score
        model['init_bbox'] = init_bbox
        model['init_rnn_state'] = init_rnn_state
        model['phase_train'] = phase_train
        model['anneal_threshold'] = anneal_threshold

        # define a CNN model
        cnn_filter = cnn_filter_size
        cnn_nlayer = len(cnn_filter)
        cnn_channel = [num_channel] + cnn_num_filter
        cnn_pool = cnn_pool_size
        cnn_act = [tf.nn.relu] * cnn_nlayer
        cnn_use_bn = [use_bn] * cnn_nlayer

        # load pretrained model
        if is_pretrain:
            h5f = h5py.File(pretrain_model_filename, 'r')

            cnn_init_w = [{'w': h5f['cnn_w_{}'.format(ii)][:],
                           'b': h5f['cnn_b_{}'.format(ii)][:]}
                          for ii in xrange(cnn_nlayer)]

            for ii in xrange(cnn_nlayer):
                for tt in xrange(3 * rnn_seq_len):
                    for w in ['beta', 'gamma']:
                        cnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[
                            'cnn_{}_0_{}'.format(ii, w)][:]
        else:
            cnn_init_w = None

        cnn_model = nn.cnn(cnn_filter, cnn_channel, cnn_pool, cnn_act,
                           cnn_use_bn, phase_train=phase_train,
                           wd=weight_decay, init_weights=cnn_init_w)

        # define a RNN(LSTM) model
        cnn_subsample = np.array(cnn_pool).prod()
        rnn_h = int(height / cnn_subsample)
        rnn_w = int(width / cnn_subsample)
        rnn_dim = cnn_channel[-1]
        cnn_out_dim = rnn_h * rnn_w * rnn_dim   # input dimension of RNN
        rnn_inp_dim = cnn_out_dim * 3

        rnn_state = [None] * (rnn_seq_len + 1)
        predict_bbox = [None] * (rnn_seq_len + 1)
        predict_score = [None] * (rnn_seq_len + 1)
        predict_bbox[0] = init_bbox
        predict_score[0] = 1

        # rnn_state[-1] = tf.zeros(tf.pack([batch_size, rnn_hidden_dim * 2]))

        # rnn_state[-1] = tf.concat(1, [inverse_transform_box(gt_bbox[:, 0, :],
        #                                                     height, width), tf.zeros(tf.pack([batch_size, rnn_hidden_dim * 2 - 4]))])

        rnn_state[-1] = init_rnn_state

        rnn_hidden_feat = [None] * rnn_seq_len

        rnn_cell = nn.lstm(rnn_inp_dim, rnn_hidden_dim, wd=weight_decay)

        # define two linear mapping MLPs:
        # RNN hidden state -> bbox
        # RNN hidden state -> score
        bbox_mlp_dims = [rnn_hidden_dim, 4]
        bbox_mlp_act = [None]

        bbox_mlp = nn.mlp(bbox_mlp_dims, bbox_mlp_act, add_bias=True,
                          phase_train=phase_train, wd=weight_decay)

        score_mlp_dims = [rnn_hidden_dim, 1]
        score_mlp_act = [tf.sigmoid]

        score_mlp = nn.mlp(score_mlp_dims, score_mlp_act,
                           add_bias=True, phase_train=phase_train, wd=weight_decay)

        # training through time
        for tt in xrange(rnn_seq_len):
            # extract global CNN feature map of the current frame
            h_cnn_global_now = cnn_model(imgs[:, tt, :, :, :])
            cnn_global_feat_now = h_cnn_global_now[-1]
            cnn_global_feat_now = tf.stop_gradient(
                cnn_global_feat_now)  # fix CNN during training
            model['cnn_global_feat_now'] = cnn_global_feat_now

            # extract ROI CNN feature map of the current frame
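            # Scheduled sampling: with probability anneal_threshold, build the
            # ROI from the model's own bbox prediction for this frame instead
            # of the groundtruth box (use_pred_bbox is a 0/1 draw per step).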
            use_pred_bbox = tf.to_float(
                tf.less(tf.random_uniform([1]), anneal_threshold))
            x1, y1, x2, y2 = tf.split(
                1, 4, use_pred_bbox * predict_bbox[tt] + (1 - use_pred_bbox) * gt_bbox[:, tt, :])
            idx_map = get_idx_map(tf.pack([batch_size, height, width]))
            mask_map = get_filled_box_idx(idx_map, tf.concat(
                1, [y1, x1]), tf.concat(1, [y2, x2]))

            ROI_img = []
            for cc in xrange(num_channel):
                ROI_img.append(imgs[:, tt, :, :, cc] * mask_map)

            h_cnn_roi_now = cnn_model(
                tf.transpose(tf.pack(ROI_img), [1, 2, 3, 0]))
            cnn_roi_feat_now = h_cnn_roi_now[-1]
            cnn_roi_feat_now = tf.stop_gradient(
                cnn_roi_feat_now)   # fix CNN during training
            model['cnn_roi_feat_now'] = cnn_roi_feat_now

            # extract global CNN feature map of the next frame
            h_cnn_global_next = cnn_model(imgs[:, tt + 1, :, :, :])
            cnn_global_feat_next = h_cnn_global_next[-1]
            cnn_global_feat_next = tf.stop_gradient(
                cnn_global_feat_next)  # fix CNN during training
            model['cnn_global_feat_next'] = cnn_global_feat_next

            # going through a RNN
            # RNN input = global CNN feat map + ROI CNN feat map
            rnn_input = tf.concat(1, [
                tf.reshape(cnn_global_feat_now, [-1, cnn_out_dim]),
                tf.reshape(cnn_roi_feat_now, [-1, cnn_out_dim]),
                tf.reshape(cnn_global_feat_next, [-1, cnn_out_dim])])

            rnn_state[tt], _, _, _ = rnn_cell(rnn_input, rnn_state[tt - 1])
            rnn_hidden_feat[tt] = tf.slice(
                rnn_state[tt], [0, rnn_hidden_dim], [-1, rnn_hidden_dim])

            # predict bbox and score
            raw_predict_bbox = bbox_mlp(rnn_hidden_feat[tt])[0]
            predict_bbox[tt + 1] = transform_box(
                raw_predict_bbox, height, width)
            predict_score[tt + 1] = score_mlp(rnn_hidden_feat[tt])[-1]

            # compute IOU
            IOU_score[tt + 1] = compute_IOU(
                predict_bbox[tt + 1], gt_bbox[:, tt + 1, :])

        model['final_rnn_state'] = rnn_state[rnn_seq_len - 1]

        # # [B, T, 4]
        # predict_bbox_reshape = tf.concat(
        #     1, [tf.expand_dims(tmp, 1) for tmp in predict_bbox[:-1]])
        # # [B, T]
        # IOU_score = f_iou_box(predict_bbox_reshape[:, :, 0: 1], predict_bbox_reshape[
        #                       :, :, 2: 3], gt_bbox[:, :, 0: 1], gt_bbox[:, :, 2: 3])

        predict_bbox = tf.transpose(tf.pack(predict_bbox[1:]), [1, 0, 2])

        model['IOU_score'] = tf.transpose(tf.pack(IOU_score[1:]), [1, 0, 2])

        # model['IOU_score'] = IOU_score
        model['predict_bbox'] = predict_bbox
        model['predict_score'] = tf.transpose(tf.pack(predict_score[1:]))

        # compute IOU loss
        batch_size_f = tf.to_float(batch_size)
        rnn_seq_len_f = tf.to_float(rnn_seq_len)
        # IOU_loss = tf.reduce_sum(gt_score * (- tf.concat(1, IOU_score))) / (batch_size_f * rnn_seq_len_f)

        valid_seq_length = tf.reduce_sum(gt_score[:, 1:], [1])
        valid_seq_length = tf.maximum(1.0, valid_seq_length)
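        # Mask the per-step IOU by gt_score (1 where a valid groundtruth box
        # exists) and normalize each sequence by its number of valid frames
        # rather than by the full rnn_seq_len.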
        IOU_loss = gt_score[:, 1:] * (- tf.concat(1, IOU_score[1:]))
        # [B,T] => [B, 1]
        IOU_loss = tf.reduce_sum(IOU_loss, [1])
        # [B, 1]
        IOU_loss /= valid_seq_length
        # [1]
        IOU_loss = tf.reduce_sum(IOU_loss) / batch_size_f

        # compute L2 loss
        # diff_bbox = gt_bbox[:, 1:, :] - predict_bbox
        # diff_x1 = diff_bbox[:, :, 0] / width
        # diff_y1 = diff_bbox[:, :, 1] / height
        # diff_x2 = diff_bbox[:, :, 2] / width
        # diff_y2 = diff_bbox[:, :, 3] / height
        # diff_bbox = tf.transpose(
        #     tf.pack([diff_x1, diff_y1, diff_x2, diff_y2]), [1, 2, 0])

        # L2_loss = tf.reduce_sum(diff_bbox * diff_bbox, [1, 2]) / 4
        # L2_loss /= valid_seq_length
        # L2_loss = tf.reduce_sum(L2_loss) / batch_size_f

        # cross-entropy loss on the predicted presence score
        pred_score_seq = tf.concat(1, predict_score[1:])
        cross_entropy = -tf.reduce_sum(
            gt_score[:, 1:] * tf.log(pred_score_seq) +
            (1 - gt_score[:, 1:]) * tf.log(1 - pred_score_seq)) / \
            (batch_size_f * rnn_seq_len_f)

        model['IOU_loss'] = IOU_loss
        # model['L2_loss'] = L2_loss
        model['CE_loss'] = cross_entropy

        global_step = tf.Variable(0.0, trainable=False)
        eps = 1e-7

        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, learn_rate_decay_step,
            learn_rate_decay_rate, staircase=True)
        model['learn_rate'] = learn_rate

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=1.0).minimize(IOU_loss + cross_entropy, global_step=global_step)
        model['train_step'] = train_step

    return model
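
A minimal usage sketch for the example above. The option values and the
zero-filled feed arrays are illustrative assumptions, and the session calls
follow the same TF 0.x-era API the snippet itself uses:

import numpy as np

B, T = 8, 10
opt = {
    'rnn_seq_len': T, 'cnn_filter_size': [3, 3], 'cnn_num_filter': [16, 32],
    'cnn_pool_size': [2, 2], 'img_channel': 3, 'use_batch_norm': True,
    'img_height': 128, 'img_width': 128, 'weight_decay': 1e-4,
    'rnn_hidden_dim': 256, 'base_learn_rate': 1e-3,
    'learn_rate_decay_step': 10000, 'learn_rate_decay_rate': 0.96,
    'is_pretrain': False, 'pretrain_model_filename': None,
}
m = build_tracking_model(opt)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
feed = {
    m['imgs']: np.zeros((B, T + 1, 128, 128, 3), 'float32'),
    m['gt_bbox']: np.zeros((B, T + 1, 4), 'float32'),
    m['gt_score']: np.ones((B, T + 1), 'float32'),
    m['init_bbox']: np.zeros((B, 4), 'float32'),
    m['init_rnn_state']: np.zeros((B, 2 * opt['rnn_hidden_dim']), 'float32'),
    m['phase_train']: True,
    m['anneal_threshold']: [0.5],
}
_, iou_loss, ce = sess.run([m['train_step'], m['IOU_loss'], m['CE_loss']],
                           feed)
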
Example #2
def get_model(opt, device='/cpu:0'):
    model = {}
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    cnn_filter_size = opt['cnn_filter_size']
    cnn_depth = opt['cnn_depth']
    cnn_pool = opt['cnn_pool']
    mlp_dims = opt['mlp_dims']
    mlp_dropout = opt['mlp_dropout']
    wd = opt['weight_decay']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']

############################
# Input definition
############################
    with tf.device(get_device_fn(device)):
        x = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth], name='x')
        phase_train = tf.placeholder('bool', name='phase_train')
        y_gt = tf.placeholder('float', [None], name='y_gt')
        global_step = tf.Variable(0.0, trainable=False)

############################
# Feature CNN definition
############################
        cnn_channels = [inp_depth] + cnn_depth
        cnn_nlayers = len(cnn_filter_size)
        cnn_use_bn = [True] * cnn_nlayers
        cnn_act = [tf.nn.relu] * cnn_nlayers
        cnn = nn.cnn(cnn_filter_size, cnn_channels, cnn_pool, cnn_act,
                     cnn_use_bn, model=model, phase_train=phase_train,
                     wd=wd, scope='cnn')
        subsample = np.array(cnn_pool).prod()
        cnn_h = inp_height / subsample
        cnn_w = inp_width / subsample
        # feat_dim = cnn_h * cnn_w * cnn_channels[-1]
        feat_dim = cnn_channels[-1]

############################
# MLP definition
############################
        mlp_nlayers = len(mlp_dims)
        mlp_dims = [feat_dim] + mlp_dims
        mlp_dropout_keep = [1 - mlp_dropout] * mlp_nlayers
        mlp_act = [tf.nn.relu] * (mlp_nlayers - 1) + [tf.sigmoid]
        mlp = nn.mlp(mlp_dims, mlp_act, model=model,
                     dropout_keep=mlp_dropout_keep, phase_train=phase_train)

############################
# Computation graph
############################
        f = cnn(x)
        f = nn.avg_pool(f[-1], cnn_h)
        f = tf.reshape(f, [-1, feat_dim])
        y_out = mlp(f)[-1]
        y_out = tf.reshape(y_out, [-1])

############################
# Loss function
############################
        num_ex = tf.shape(y_gt)[0]
        num_ex_f = tf.to_float(num_ex)
        bce = f_bce(y_out, y_gt)
        bce = tf.reduce_sum(bce) / num_ex_f
        tf.add_to_collection('losses', bce)
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')

############################
# Statistics
############################
        y_out_thresh = tf.to_float(y_out > 0.5)
        acc = tf.reduce_sum(
            tf.to_float(tf.equal(y_out_thresh, y_gt))) / num_ex_f

####################
# Optimizer
####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        eps = 1e-7
        train_step = tf.train.AdamOptimizer(learn_rate, epsilon=eps).minimize(
            total_loss, global_step=global_step)

############################
# Computation nodes
############################
        model['x'] = x
        model['y_gt'] = y_gt
        model['phase_train'] = phase_train
        model['y_out'] = y_out
        model['loss'] = total_loss
        model['acc'] = acc
        model['learn_rate'] = learn_rate
        model['train_step'] = train_step

    return model
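
A short training-step sketch for the classifier above; the option values and
the zero-filled arrays are illustrative assumptions (any data matching the
declared placeholder shapes works), using the same TF 0.x session API:

import numpy as np

opt = {
    'inp_height': 32, 'inp_width': 32, 'inp_depth': 3,
    'cnn_filter_size': [3, 3], 'cnn_depth': [16, 32], 'cnn_pool': [2, 2],
    'mlp_dims': [64, 1], 'mlp_dropout': 0.5, 'weight_decay': 1e-4,
    'base_learn_rate': 1e-3, 'learn_rate_decay': 0.96,
    'steps_per_learn_rate_decay': 5000,
}
m = get_model(opt)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
feed = {m['x']: np.zeros((16, 32, 32, 3), 'float32'),
        m['y_gt']: np.zeros((16,), 'float32'),
        m['phase_train']: True}
_, loss, acc = sess.run([m['train_step'], m['loss'], m['acc']], feed)
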
Example #3
def get_model(opt, is_training=True):
    """The attention model"""
    log = logger.get()
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']

    mlp_dropout_ratio = opt['mlp_dropout']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    segm_loss_fn = opt['segm_loss_fn']
    box_loss_fn = opt['box_loss_fn']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    use_knob = opt['use_knob']
    knob_base = opt['knob_base']
    knob_decay = opt['knob_decay']
    steps_per_knob_decay = opt['steps_per_knob_decay']
    knob_box_offset = opt['knob_box_offset']
    knob_segm_offset = opt['knob_segm_offset']
    knob_use_timescale = opt['knob_use_timescale']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']

    squash_ctrl_params = opt['squash_ctrl_params']
    fixed_order = opt['fixed_order']
    clip_gradient = opt['clip_gradient']
    fixed_gamma = opt['fixed_gamma']
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']
    pretrain_ctrl_net = opt['pretrain_ctrl_net']
    pretrain_attn_net = opt['pretrain_attn_net']
    pretrain_net = opt['pretrain_net']

    if 'freeze_ctrl_cnn' in opt:
        freeze_ctrl_cnn = opt['freeze_ctrl_cnn']
        freeze_ctrl_rnn = opt['freeze_ctrl_rnn']
        freeze_attn_net = opt['freeze_attn_net']
    else:
        freeze_ctrl_cnn = True
        freeze_ctrl_rnn = True
        freeze_attn_net = True

    if 'freeze_ctrl_mlp' in opt:
        freeze_ctrl_mlp = opt['freeze_ctrl_mlp']
    else:
        freeze_ctrl_mlp = freeze_ctrl_rnn

    if 'fixed_var' in opt:
        fixed_var = opt['fixed_var']
    else:
        fixed_var = False

    if 'dynamic_var' in opt:
        dynamic_var = opt['dynamic_var']
    else:
        dynamic_var = False

    if 'use_iou_box' in opt:
        use_iou_box = opt['use_iou_box']
    else:
        use_iou_box = False

    if 'stop_canvas_grad' in opt:
        stop_canvas_grad = opt['stop_canvas_grad']
    else:
        stop_canvas_grad = True

    if 'add_skip_conn' in opt:
        add_skip_conn = opt['add_skip_conn']
    else:
        add_skip_conn = True

    if 'attn_cnn_skip' in opt:
        attn_cnn_skip = opt['attn_cnn_skip']
    else:
        attn_cnn_skip = [add_skip_conn] * len(attn_cnn_filter_size)

    if 'disable_overwrite' in opt:
        disable_overwrite = opt['disable_overwrite']
    else:
        disable_overwrite = True

    if 'add_d_out' in opt:
        add_d_out = opt['add_d_out']
        add_y_out = opt['add_y_out']
    else:
        add_d_out = False
        add_y_out = False

    if 'attn_add_d_out' in opt:
        attn_add_d_out = opt['attn_add_d_out']
        attn_add_y_out = opt['attn_add_y_out']
        attn_add_inp = opt['attn_add_inp']
        attn_add_canvas = opt['attn_add_canvas']
    else:
        attn_add_d_out = add_d_out
        attn_add_y_out = add_y_out
        attn_add_inp = True
        attn_add_canvas = True

    if 'ctrl_add_d_out' in opt:
        ctrl_add_d_out = opt['ctrl_add_d_out']
        ctrl_add_y_out = opt['ctrl_add_y_out']
        ctrl_add_inp = opt['ctrl_add_inp']
        ctrl_add_canvas = opt['ctrl_add_canvas']
    else:
        ctrl_add_d_out = add_d_out
        ctrl_add_y_out = add_y_out
        ctrl_add_inp = not ctrl_add_d_out
        ctrl_add_canvas = not ctrl_add_d_out

    if 'num_semantic_classes' in opt:
        num_semantic_classes = opt['num_semantic_classes']
    else:
        num_semantic_classes = 1

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    # Input image, [B, H, W, D]
    x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth],
                       name='x')
    x_shape = tf.shape(x)
    num_ex = x_shape[0]

    # Groundtruth segmentation, [B, T, H, W]
    y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width],
                          name='y_gt')

    # Groundtruth confidence score, [B, T]
    s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

    if add_d_out:
        d_in = tf.placeholder('float', [None, inp_height, inp_width, 8],
                              name='d_in')
        model['d_in'] = d_in
    if add_y_out:
        y_in = tf.placeholder(
            'float', [None, inp_height, inp_width, num_semantic_classes],
            name='y_in')
        model['y_in'] = y_in

    # Whether in training stage.
    phase_train = tf.placeholder('bool', name='phase_train')
    phase_train_f = tf.to_float(phase_train)

    model['x'] = x
    model['y_gt'] = y_gt
    model['s_gt'] = s_gt
    model['phase_train'] = phase_train

    # Global step
    if 'freeze_ctrl_cnn' in opt:
        global_step = tf.Variable(0.0, name='global_step', trainable=False)
    else:
        global_step = tf.Variable(0.0, trainable=False)

    ###############################
    # Random input transformation
    ###############################
    # Either add both or add nothing.
    assert (add_d_out and add_y_out) or (not add_d_out and not add_y_out)
    if not add_d_out:
        results = img.random_transformation(x,
                                            padding,
                                            phase_train,
                                            rnd_hflip=rnd_hflip,
                                            rnd_vflip=rnd_vflip,
                                            rnd_transpose=rnd_transpose,
                                            rnd_colour=rnd_colour,
                                            y=y_gt)
        x, y_gt = results['x'], results['y']
    else:
        results = img.random_transformation(x,
                                            padding,
                                            phase_train,
                                            rnd_hflip=rnd_hflip,
                                            rnd_vflip=rnd_vflip,
                                            rnd_transpose=rnd_transpose,
                                            rnd_colour=rnd_colour,
                                            y=y_gt,
                                            d=d_in,
                                            c=y_in)
        x, y_gt, d_in, y_in = results['x'], results['y'], results[
            'd'], results['c']
        model['d_in_trans'] = d_in
        model['y_in_trans'] = y_in
    model['x_trans'] = x
    model['y_gt_trans'] = y_gt

    ############################
    # Canvas: external memory
    ############################
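    # The canvas stores every mask emitted so far (updated by elementwise max
    # in the timestep loop below); it can be fed back as an extra input
    # channel to the controller and attention CNNs.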
    canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
    ccnn_inp_depth = 0
    acnn_inp_depth = 0
    if ctrl_add_inp:
        ccnn_inp_depth += inp_depth
    if ctrl_add_canvas:
        ccnn_inp_depth += 1
    if attn_add_inp:
        acnn_inp_depth += inp_depth
    if attn_add_canvas:
        acnn_inp_depth += 1

    if ctrl_add_d_out:
        ccnn_inp_depth += 8
    if ctrl_add_y_out:
        ccnn_inp_depth += num_semantic_classes
    if attn_add_d_out:
        acnn_inp_depth += 8
    if attn_add_y_out:
        acnn_inp_depth += num_semantic_classes

    #############################
    # Controller CNN definition
    #############################
    ccnn_filters = ctrl_cnn_filter_size
    ccnn_nlayers = len(ccnn_filters)
    acnn_nlayers = len(attn_cnn_filter_size)
    ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
    ccnn_pool = ctrl_cnn_pool
    ccnn_act = [tf.nn.relu] * ccnn_nlayers
    ccnn_use_bn = [use_bn] * ccnn_nlayers

    pt = pretrain_net or pretrain_ctrl_net
    if pt:
        log.info(
            'Loading pretrained controller CNN weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            ccnn_init_w = [{
                'w': h5f['ctrl_cnn_w_{}'.format(ii)][:],
                'b': h5f['ctrl_cnn_b_{}'.format(ii)][:]
            } for ii in range(ccnn_nlayers)]
            for ii in range(ccnn_nlayers):
                for tt in range(timespan):
                    for w in ['beta', 'gamma']:
                        ccnn_init_w[ii]['{}_{}'.format(
                            w,
                            tt)] = h5f['ctrl_cnn_{}_{}_{}'.format(ii, tt,
                                                                  w)][:]
        ccnn_frozen = [freeze_ctrl_cnn] * ccnn_nlayers
    else:
        ccnn_init_w = None
        ccnn_frozen = [freeze_ctrl_cnn] * ccnn_nlayers

    ccnn = nn.cnn(ccnn_filters,
                  ccnn_channels,
                  ccnn_pool,
                  ccnn_act,
                  ccnn_use_bn,
                  phase_train=phase_train,
                  wd=wd,
                  scope='ctrl_cnn',
                  model=model,
                  init_weights=ccnn_init_w,
                  frozen=ccnn_frozen)
    h_ccnn = [None] * timespan

    ############################
    # Controller RNN definition
    ############################
    ccnn_subsample = np.array(ccnn_pool).prod()
    crnn_h = inp_height / ccnn_subsample
    crnn_w = inp_width / ccnn_subsample
    crnn_dim = ctrl_rnn_hid_dim
    canvas_dim = inp_height * inp_width / (ccnn_subsample**2)

    glimpse_map_dim = crnn_h * crnn_w
    glimpse_feat_dim = ccnn_channels[-1]
    crnn_inp_dim = glimpse_feat_dim

    pt = pretrain_net or pretrain_ctrl_net
    if pt:
        log.info(
            'Loading pretrained controller RNN weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            crnn_init_w = {}
            for w in [
                    'w_xi', 'w_hi', 'b_i', 'w_xf', 'w_hf', 'b_f', 'w_xu',
                    'w_hu', 'b_u', 'w_xo', 'w_ho', 'b_o'
            ]:
                key = 'ctrl_lstm_{}'.format(w)
                crnn_init_w[w] = h5f[key][:]
            crnn_frozen = freeze_ctrl_rnn
    else:
        crnn_init_w = None
        crnn_frozen = freeze_ctrl_rnn

    crnn_state = [None] * (timespan + 1)
    crnn_glimpse_map = [None] * timespan
    crnn_g_i = [None] * timespan
    crnn_g_f = [None] * timespan
    crnn_g_o = [None] * timespan
    h_crnn = [None] * timespan
    crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
    crnn_cell = nn.lstm(crnn_inp_dim,
                        crnn_dim,
                        wd=wd,
                        scope='ctrl_lstm',
                        init_weights=crnn_init_w,
                        frozen=crnn_frozen,
                        model=model)

    ############################
    # Glimpse MLP definition
    ############################
    gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
    gmlp_act = [tf.nn.relu] * \
        (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
    gmlp_dropout = None

    pt = pretrain_net or pretrain_ctrl_net
    if pt:
        log.info('Loading pretrained glimpse MLP weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            gmlp_init_w = [{
                'w': h5f['glimpse_mlp_w_{}'.format(ii)][:],
                'b': h5f['glimpse_mlp_b_{}'.format(ii)][:]
            } for ii in range(num_glimpse_mlp_layers)]
            gmlp_frozen = [freeze_ctrl_rnn] * num_glimpse_mlp_layers
    else:
        gmlp_init_w = None
        gmlp_frozen = [freeze_ctrl_rnn] * num_glimpse_mlp_layers

    gmlp = nn.mlp(gmlp_dims,
                  gmlp_act,
                  add_bias=True,
                  dropout_keep=gmlp_dropout,
                  phase_train=phase_train,
                  wd=wd,
                  scope='glimpse_mlp',
                  init_weights=gmlp_init_w,
                  frozen=gmlp_frozen,
                  model=model)

    ############################
    # Controller MLP definition
    ############################
    cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
        (num_ctrl_mlp_layers - 1) + [9]
    cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
    cmlp_dropout = None

    pt = pretrain_net or pretrain_ctrl_net
    if pt:
        log.info(
            'Loading pretrained controller MLP weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            cmlp_init_w = [{
                'w': h5f['ctrl_mlp_w_{}'.format(ii)][:],
                'b': h5f['ctrl_mlp_b_{}'.format(ii)][:]
            } for ii in range(num_ctrl_mlp_layers)]
            cmlp_frozen = [freeze_ctrl_mlp] * num_ctrl_mlp_layers
    else:
        cmlp_init_w = None
        cmlp_frozen = [freeze_ctrl_mlp] * num_ctrl_mlp_layers

    cmlp = nn.mlp(cmlp_dims,
                  cmlp_act,
                  add_bias=True,
                  dropout_keep=cmlp_dropout,
                  phase_train=phase_train,
                  wd=wd,
                  scope='ctrl_mlp',
                  init_weights=cmlp_init_w,
                  frozen=cmlp_frozen,
                  model=model)

    ###########################
    # Attention CNN definition
    ###########################
    acnn_filters = attn_cnn_filter_size
    acnn_nlayers = len(acnn_filters)
    acnn_channels = [acnn_inp_depth] + attn_cnn_depth
    acnn_pool = attn_cnn_pool
    acnn_act = [tf.nn.relu] * acnn_nlayers
    acnn_use_bn = [use_bn] * acnn_nlayers

    pt = pretrain_net or pretrain_attn_net
    if pt:
        log.info('Loading pretrained attention CNN weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            acnn_init_w = [{
                'w': h5f['attn_cnn_w_{}'.format(ii)][:],
                'b': h5f['attn_cnn_b_{}'.format(ii)][:]
            } for ii in range(acnn_nlayers)]
            for ii in range(acnn_nlayers):
                for tt in range(timespan):
                    for w in ['beta', 'gamma']:
                        key = 'attn_cnn_{}_{}_{}'.format(ii, tt, w)
                        acnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[key][:]
        acnn_frozen = [freeze_attn_net] * acnn_nlayers
    else:
        acnn_init_w = None
        acnn_frozen = [freeze_attn_net] * acnn_nlayers

    acnn = nn.cnn(acnn_filters,
                  acnn_channels,
                  acnn_pool,
                  acnn_act,
                  acnn_use_bn,
                  phase_train=phase_train,
                  wd=wd,
                  scope='attn_cnn',
                  model=model,
                  init_weights=acnn_init_w,
                  frozen=acnn_frozen)

    x_patch = [None] * timespan
    h_acnn = [None] * timespan
    h_acnn_last = [None] * timespan

    acnn_subsample = np.array(acnn_pool).prod()
    acnn_h = filter_height / acnn_subsample
    acnn_w = filter_width / acnn_subsample
    core_depth = acnn_channels[-1]
    core_dim = acnn_h * acnn_w * core_depth

    ##########################
    # Score MLP definition
    ##########################
    pt = pretrain_net
    if pt:
        log.info('Loading score mlp weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            smlp_init_w = [{
                'w': h5f['score_mlp_w_{}'.format(ii)][:],
                'b': h5f['score_mlp_b_{}'.format(ii)][:]
            } for ii in range(1)]
    else:
        smlp_init_w = None
    smlp = nn.mlp([crnn_dim + core_dim, 1], [tf.sigmoid],
                  wd=wd,
                  scope='score_mlp',
                  init_weights=smlp_init_w,
                  model=model)
    s_out = [None] * timespan

    #############################
    # Attention DCNN definition
    #############################
    adcnn_filters = attn_dcnn_filter_size
    adcnn_nlayers = len(adcnn_filters)
    adcnn_unpool = attn_dcnn_pool
    adcnn_act = [tf.nn.relu] * adcnn_nlayers
    adcnn_channels = [core_depth] + attn_dcnn_depth

    adcnn_bn_nlayers = adcnn_nlayers
    adcnn_use_bn = [use_bn] * adcnn_bn_nlayers + \
        [False] * (adcnn_nlayers - adcnn_bn_nlayers)

    if add_skip_conn:
        adcnn_skip_ch = [0]
        adcnn_channels_rev = acnn_channels[::-1][1:] + [acnn_inp_depth]
        adcnn_skip_rev = attn_cnn_skip[::-1]
        for sk, ch in zip(adcnn_skip_rev, adcnn_channels_rev):
            adcnn_skip_ch.append(ch if sk else 0)
    else:
        adcnn_skip_ch = None

    pt = pretrain_net or pretrain_attn_net
    if pt:
        log.info(
            'Loading pretrained attention DCNN weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            adcnn_init_w = [{
                'w': h5f['attn_dcnn_w_{}'.format(ii)][:],
                'b': h5f['attn_dcnn_b_{}'.format(ii)][:]
            } for ii in range(adcnn_nlayers)]
            for ii in range(adcnn_bn_nlayers):
                for tt in range(timespan):
                    for w in ['beta', 'gamma']:
                        key = 'attn_dcnn_{}_{}_{}'.format(ii, tt, w)
                        adcnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[key][:]

        adcnn_frozen = [freeze_attn_net] * adcnn_nlayers
    else:
        adcnn_init_w = None
        adcnn_frozen = [freeze_attn_net] * adcnn_nlayers

    adcnn = nn.dcnn(adcnn_filters,
                    adcnn_channels,
                    adcnn_unpool,
                    adcnn_act,
                    use_bn=adcnn_use_bn,
                    skip_ch=adcnn_skip_ch,
                    phase_train=phase_train,
                    wd=wd,
                    model=model,
                    init_weights=adcnn_init_w,
                    frozen=adcnn_frozen,
                    scope='attn_dcnn')
    h_adcnn = [None] * timespan

    ##########################
    # Attention box
    ##########################
    attn_ctr_norm = [None] * timespan
    attn_lg_size = [None] * timespan
    attn_ctr = [None] * timespan
    attn_size = [None] * timespan
    attn_lg_var = [None] * timespan
    attn_lg_gamma = [None] * timespan
    attn_gamma = [None] * timespan
    attn_box_lg_gamma = [None] * timespan
    attn_top_left = [None] * timespan
    attn_bot_right = [None] * timespan
    attn_box = [None] * timespan
    iou_soft_box = [None] * timespan
    const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
    attn_box_beta = tf.constant([-5.0])
    attn_box_gamma = [None] * timespan

    #############################
    # Groundtruth attention box
    #############################
    # [B, T, 2]
    attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_lg_gamma_gt, \
        attn_box_gt, \
        attn_top_left_gt, attn_bot_right_gt = \
        modellib.get_gt_attn(y_gt, filter_height, filter_width,
                             padding_ratio=attn_box_padding_ratio,
                             center_shift_ratio=0.0,
                             min_padding=padding + 4)
    attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
        attn_lg_gamma_gt_noise, \
        attn_box_gt_noise, \
        attn_top_left_gt_noise, attn_bot_right_gt_noise = \
        modellib.get_gt_attn(y_gt, filter_height, filter_width,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise),
                             min_padding=padding + 4)
    attn_ctr_norm_gt = modellib.get_normalized_center(attn_ctr_gt, inp_height,
                                                      inp_width)
    attn_lg_size_gt = modellib.get_normalized_size(attn_size_gt, inp_height,
                                                   inp_width)

    ##########################
    # Groundtruth mix
    ##########################
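    # Cumulative match indicator consumed by f_greedy_match below, so that a
    # groundtruth instance is claimed by at most one timestep when mixing in
    # groundtruth boxes.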
    grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

    # Scale mix ratio on different timesteps.
    if knob_use_timescale:
        gt_knob_time_scale = tf.reshape(
            1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0),
            [1, timespan, 1])
    else:
        gt_knob_time_scale = tf.ones([1, timespan, 1])

    # Mix in groundtruth box.
    global_step_box = tf.maximum(0.0, global_step - knob_box_offset)
    gt_knob_prob_box = tf.train.exponential_decay(knob_base,
                                                  global_step_box,
                                                  steps_per_knob_decay,
                                                  knob_decay,
                                                  staircase=False)
    gt_knob_prob_box = tf.minimum(1.0, gt_knob_prob_box * gt_knob_time_scale)
    gt_knob_box = tf.to_float(
        tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <=
        gt_knob_prob_box)
    model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0]

    # Mix in groundtruth segmentation.
    global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset)
    gt_knob_prob_segm = tf.train.exponential_decay(knob_base,
                                                   global_step_segm,
                                                   steps_per_knob_decay,
                                                   knob_decay,
                                                   staircase=False)
    gt_knob_prob_segm = tf.minimum(1.0, gt_knob_prob_segm * gt_knob_time_scale)
    gt_knob_segm = tf.to_float(
        tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <=
        gt_knob_prob_segm)
    model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0]

    ##########################
    # Segmentation output
    ##########################
    y_out_patch = [None] * timespan
    y_out = [None] * timespan
    y_out_lg_gamma = [None] * timespan
    y_out_beta = tf.constant([-5.0])

    ##########################
    # Computation graph
    ##########################
    for tt in range(timespan):
        # Controller CNN
        ccnn_inp_list = []
        acnn_inp_list = []

        if ctrl_add_inp:
            ccnn_inp_list.append(x)
        if attn_add_inp:
            acnn_inp_list.append(x)
        if ctrl_add_canvas:
            ccnn_inp_list.append(canvas)
        if attn_add_canvas:
            acnn_inp_list.append(canvas)
        if ctrl_add_d_out:
            ccnn_inp_list.append(d_in)
        if attn_add_d_out:
            acnn_inp_list.append(d_in)
        if ctrl_add_y_out:
            ccnn_inp_list.append(y_in)
        if attn_add_y_out:
            acnn_inp_list.append(y_in)

        acnn_inp = tf.concat(3, acnn_inp_list)
        ccnn_inp = tf.concat(3, ccnn_inp_list)

        h_ccnn[tt] = ccnn(ccnn_inp)
        _h_ccnn = h_ccnn[tt]
        h_ccnn_last = _h_ccnn[-1]

        # Controller RNN [B, R1]
        crnn_inp = tf.reshape(h_ccnn_last,
                              [-1, glimpse_map_dim, glimpse_feat_dim])
        crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
        crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
        crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
        crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
        h_crnn[tt] = [None] * num_ctrl_rnn_iter
        crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
        crnn_glimpse_map[tt][0] = tf.ones(
            tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim
        # Inner glimpse RNN
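        # Each inner iteration pools the CNN feature map under the current
        # glimpse map (soft attention over the H' x W' grid), updates the
        # controller LSTM, and lets the glimpse MLP emit the next map.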
        for tt2 in range(num_ctrl_rnn_iter):
            crnn_glimpse = tf.reduce_sum(crnn_inp * crnn_glimpse_map[tt][tt2],
                                         [1])
            crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
                crnn_g_o[tt][tt2] = crnn_cell(
                    crnn_glimpse, crnn_state[tt][tt2 - 1])
            h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2], [0, crnn_dim],
                                       [-1, crnn_dim])
            h_gmlp = gmlp(h_crnn[tt][tt2])
            if tt2 < num_ctrl_rnn_iter - 1:
                crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(h_gmlp[-1], 2)
        ctrl_out = cmlp(h_crnn[tt][-1])[-1]

        attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
        attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

        # Restrict to (-1, 1), (-inf, 0)
        if squash_ctrl_params:
            attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
            attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

        attn_ctr[tt], attn_size[tt] = modellib.get_unnormalized_attn(
            attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)

        if fixed_var:
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
        else:
            attn_lg_var[tt] = modellib.get_normalized_var(
                attn_size[tt], filter_height, filter_width)

        if dynamic_var:
            attn_lg_var[tt] = tf.slice(ctrl_out, [0, 4], [-1, 2])

        if fixed_gamma:
            attn_lg_gamma[tt] = tf.constant([0.0])
            y_out_lg_gamma[tt] = tf.constant([2.0])
        else:
            attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1])
            y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1])

        attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
        attn_gamma[tt] = tf.reshape(tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1])
        attn_box_gamma[tt] = tf.reshape(tf.exp(attn_box_lg_gamma[tt]),
                                        [-1, 1, 1, 1])
        y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1])

        attn_top_left[tt], attn_bot_right[tt] = modellib.get_box_coord(
            attn_ctr[tt], attn_size[tt])

        # Initial filters (predicted)
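        # Separable DRAW-style Gaussian attention filters along y and x; the
        # transposed filters are used later to paste patch-space outputs back
        # into image space.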
        filter_y = modellib.get_gaussian_filter(attn_ctr[tt][:, 0],
                                                attn_size[tt][:, 0],
                                                attn_lg_var[tt][:, 0],
                                                inp_height, filter_height)
        filter_x = modellib.get_gaussian_filter(attn_ctr[tt][:, 1],
                                                attn_size[tt][:, 1],
                                                attn_lg_var[tt][:, 1],
                                                inp_width, filter_width)
        filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
        filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

        # Attention box
        attn_box[tt] = modellib.extract_patch(const_ones * attn_box_gamma[tt],
                                              filter_y_inv, filter_x_inv, 1)
        attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
        attn_box[tt] = tf.reshape(attn_box[tt], [-1, 1, inp_height, inp_width])

        # Knob: mix in the groundtruth bbox.
        if use_knob:
            if fixed_order:
                attn_ctr_gtm = attn_ctr_gt_noise[:, tt, :]
                # attn_delta_gtm = attn_delta_gt_noise[:, tt, :]
                attn_size_gtm = attn_size_gt_noise[:, tt, :]
            else:
                if use_iou_box:
                    iou_soft_box[tt] = modellib.f_iou_box(
                        tf.expand_dims(attn_top_left[tt], 1),
                        tf.expand_dims(attn_bot_right[tt], 1),
                        attn_top_left_gt, attn_bot_right_gt)
                else:
                    iou_soft_box[tt] = modellib.f_inter(
                        attn_box[tt], attn_box_gt) / \
                        modellib.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
                grd_match = modellib.f_greedy_match(iou_soft_box[tt],
                                                    grd_match_cum)

                # [B, T, 1]
                grd_match = tf.expand_dims(grd_match, 2)
                attn_ctr_gtm = tf.reduce_sum(grd_match * attn_ctr_gt_noise, 1)
                attn_size_gtm = tf.reduce_sum(grd_match * attn_size_gt_noise,
                                              1)

            attn_ctr[tt] = phase_train_f * gt_knob_box[:, tt, 0: 1] * \
                attn_ctr_gtm + \
                (1 - phase_train_f * gt_knob_box[:, tt, 0: 1]) * \
                attn_ctr[tt]
            attn_size[tt] = phase_train_f * gt_knob_box[:, tt, 0: 1] * \
                attn_size_gtm + \
                (1 - phase_train_f * gt_knob_box[:, tt, 0: 1]) * \
                attn_size[tt]

        attn_top_left[tt], attn_bot_right[tt] = modellib.get_box_coord(
            attn_ctr[tt], attn_size[tt])

        filter_y = modellib.get_gaussian_filter(attn_ctr[tt][:, 0],
                                                attn_size[tt][:, 0],
                                                attn_lg_var[tt][:, 0],
                                                inp_height, filter_height)
        filter_x = modellib.get_gaussian_filter(attn_ctr[tt][:, 1],
                                                attn_size[tt][:, 1],
                                                attn_lg_var[tt][:, 1],
                                                inp_width, filter_width)
        filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
        filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

        # Attended patch [B, A, A, D]
        x_patch[tt] = attn_gamma[tt] * modellib.extract_patch(
            acnn_inp, filter_y, filter_x, acnn_inp_depth)

        # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
        h_acnn[tt] = acnn(x_patch[tt])
        h_acnn_last[tt] = h_acnn[tt][-1]
        h_core = tf.reshape(h_acnn_last[tt], [-1, core_dim])
        h_core_img = h_acnn_last[tt]

        # DCNN
        if add_skip_conn:
            h_acnn_rev = h_acnn[tt][::-1][1:] + [x_patch[tt]]
            adcnn_skip = [None]
            for sk, hcnn in zip(adcnn_skip_rev, h_acnn_rev):
                adcnn_skip.append(hcnn if sk else None)
        else:
            adcnn_skip = None
        h_adcnn[tt] = adcnn(h_core_img, skip=adcnn_skip)
        y_out_patch[tt] = tf.expand_dims(h_adcnn[tt][-1], 1)

        # Output
        y_out[tt] = modellib.extract_patch(h_adcnn[tt][-1], filter_y_inv,
                                           filter_x_inv, 1)
        y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta
        y_out[tt] = tf.sigmoid(y_out[tt])
        y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width])

        if disable_overwrite:
            y_out[tt] = tf.reshape(1 - canvas,
                                   [-1, 1, inp_height, inp_width]) * y_out[tt]

        # Scoring network
        smlp_inp = tf.concat(1, [h_crnn[tt][-1], h_core])
        s_out[tt] = smlp(smlp_inp)[-1]

        # Knob: mix in the groundtruth segmentation at this timestep.
        # [B, N, 1, 1]
        if use_knob:
            _gt_knob_segm = tf.expand_dims(
                tf.expand_dims(gt_knob_segm[:, tt, 0:1], 2), 3)

            if fixed_order:
                _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
            else:
                grd_match = tf.expand_dims(grd_match, 3)
                _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1), 3)
            # Add independent uniform noise to groundtruth.
            _noise = tf.random_uniform(
                tf.pack([num_ex, inp_height, inp_width, 1]), 0, gt_segm_noise)
            _y_out = _y_out - _y_out * _noise
            _y_out = phase_train_f * _gt_knob_segm * _y_out + \
                (1 - phase_train_f * _gt_knob_segm) * \
                tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1])
        else:
            _y_out = tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1])
        y_out_last = _y_out
        canvas = tf.maximum(_y_out, canvas)
        if stop_canvas_grad:
            canvas = tf.stop_gradient(canvas)
            y_out_last = tf.stop_gradient(y_out_last)

    #########################
    # Model outputs
    #########################
    s_out = tf.concat(1, s_out)
    model['s_out'] = s_out
    y_out = tf.concat(1, y_out)
    model['y_out'] = y_out
    y_out_patch = tf.concat(1, y_out_patch)
    model['y_out_patch'] = y_out_patch
    attn_box = tf.concat(1, attn_box)
    model['attn_box'] = attn_box
    x_patch = tf.concat(
        1, [tf.expand_dims(x_patch[tt], 1) for tt in range(timespan)])
    model['x_patch'] = x_patch

    attn_top_left = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left])
    attn_bot_right = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right])
    attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
    attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
    attn_lg_gamma = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_gamma])
    attn_box_lg_gamma = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_box_lg_gamma])
    y_out_lg_gamma = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in y_out_lg_gamma])
    model['attn_ctr'] = attn_ctr
    model['attn_size'] = attn_size
    model['attn_top_left'] = attn_top_left
    model['attn_bot_right'] = attn_bot_right
    model['attn_ctr_gt'] = attn_ctr_gt
    model['attn_size_gt'] = attn_size_gt
    model['attn_top_left_gt'] = attn_top_left_gt
    model['attn_bot_right_gt'] = attn_bot_right_gt
    model['attn_box_gt'] = attn_box_gt
    attn_ctr_norm = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr_norm])
    attn_lg_size = tf.concat(1,
                             [tf.expand_dims(tmp, 1) for tmp in attn_lg_size])
    model['attn_ctr_norm'] = attn_ctr_norm
    model['attn_lg_size'] = attn_lg_size
    attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
    attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

    ####################
    # Glimpse
    ####################
    # T * T2 * [H', W'] => [T, T2, H', W']
    crnn_glimpse_map = tf.concat(1, [
        tf.expand_dims(
            tf.concat(1, [
                tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
                for tt2 in range(num_ctrl_rnn_iter)
            ]), 1) for tt in range(timespan)
    ])
    crnn_glimpse_map = tf.reshape(
        crnn_glimpse_map, [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w])
    model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

    model['global_step'] = global_step
    if not is_training:
        return model

    #########################
    # Loss function
    #########################
    num_ex_f = tf.to_float(x_shape[0])
    max_num_obj = tf.to_float(timespan)

    ############################
    # Box loss
    ############################
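    # Soft IOU between predicted and groundtruth attention boxes: [B, T] when
    # the order is fixed, otherwise a [B, T, T] pairwise matrix resolved by
    # matching below.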
    if fixed_order:
        # [B, T] for fixed order.
        iou_soft_box = modellib.f_iou(attn_box, attn_box_gt, pairwise=False)
    else:
        if use_knob:
            # [B, T, T] for matching.
            iou_soft_box = tf.concat(1, [
                tf.expand_dims(iou_soft_box[tt], 1) for tt in range(timespan)
            ])
        else:
            iou_soft_box = modellib.f_iou(attn_box,
                                          attn_box_gt,
                                          timespan,
                                          pairwise=True)
        # iou_soft_box = modellib.f_iou_pair_new(attn_box, attn_box_gt)

    identity_match = modellib.get_identity_match(num_ex, timespan, s_gt)
    if fixed_order:
        match_box = identity_match
    else:
        match_box = modellib.f_segm_match(iou_soft_box, s_gt)

    model['match_box'] = match_box
    match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
    match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
    match_count_box = tf.maximum(1.0, match_count_box)

    # [B] if fixed order, [B, T] if matching.
    if fixed_order:
        iou_soft_box_mask = iou_soft_box
    else:
        iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
    iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
    iou_soft_box = tf.reduce_sum(iou_soft_box / match_count_box) / num_ex_f

    if box_loss_fn == 'mse':
        box_loss = modellib.f_match_loss(attn_params,
                                         attn_params_gt,
                                         match_box,
                                         timespan,
                                         modellib.f_squared_err,
                                         model=model)
    elif box_loss_fn == 'huber':
        box_loss = modellib.f_match_loss(attn_params, attn_params_gt,
                                         match_box, timespan, modellib.f_huber)
    elif box_loss_fn == 'iou':
        box_loss = -iou_soft_box
    elif box_loss_fn == 'wt_cov':
        box_loss = -modellib.f_weighted_coverage(iou_soft_box, attn_box_gt)
    elif box_loss_fn == 'bce':
        box_loss = modellib.f_match_loss(y_out, y_gt, match_box, timespan,
                                         f_bce)
    else:
        raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
    model['box_loss'] = box_loss

    box_loss_coeff = tf.constant(1.0)
    model['box_loss_coeff'] = box_loss_coeff
    tf.add_to_collection('losses', box_loss_coeff * box_loss)

    ##############################
    # Segmentation loss
    ##############################
    # IoU (soft)
    iou_soft_pairwise = modellib.f_iou(y_out, y_gt, timespan, pairwise=True)
    real_match = modellib.f_segm_match(iou_soft_pairwise, s_gt)
    if fixed_order:
        iou_soft = modellib.f_iou(y_out, y_gt, pairwise=False)
        match = identity_match
    else:
        iou_soft = iou_soft_pairwise
        match = real_match
    model['match'] = match
    match_sum = tf.reduce_sum(match, reduction_indices=[2])
    match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
    match_count = tf.maximum(1.0, match_count)

    # Weighted coverage (soft)
    wt_cov_soft = modellib.f_weighted_coverage(iou_soft_pairwise, y_gt)
    model['wt_cov_soft'] = wt_cov_soft
    unwt_cov_soft = modellib.f_unweighted_coverage(iou_soft_pairwise,
                                                   match_count)
    model['unwt_cov_soft'] = unwt_cov_soft

    # [B] if fixed order, [B, T] if matching.
    if fixed_order:
        iou_soft_mask = iou_soft
    else:
        iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
    iou_soft = tf.reduce_sum(iou_soft_mask, [1])
    iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f
    model['iou_soft'] = iou_soft

    if segm_loss_fn == 'iou':
        segm_loss = -iou_soft
    elif segm_loss_fn == 'wt_cov':
        segm_loss = -wt_cov_soft
    elif segm_loss_fn == 'bce':
        segm_loss = f_match_bce(y_out, y_gt, match, timespan)
    else:
        raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
    model['segm_loss'] = segm_loss
    segm_loss_coeff = tf.constant(1.0)
    tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

    ####################
    # Score loss
    ####################
    conf_loss = modellib.f_conf_loss(s_out, match, timespan, use_cum_min=True)
    model['conf_loss'] = conf_loss
    tf.add_to_collection('losses', loss_mix_ratio * conf_loss)

    ####################
    # Total loss
    ####################
    total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    model['loss'] = total_loss

    ####################
    # Optimizer
    ####################
    learn_rate = tf.train.exponential_decay(base_learn_rate,
                                            global_step,
                                            steps_per_learn_rate_decay,
                                            learn_rate_decay,
                                            staircase=True)
    model['learn_rate'] = learn_rate
    eps = 1e-7

    optimizer = tf.train.AdamOptimizer(learn_rate, epsilon=eps)
    gvs = optimizer.compute_gradients(total_loss)
    capped_gvs = []
    for grad, var in gvs:
        if grad is not None:
            capped_gvs.append((tf.clip_by_value(grad, -1, 1), var))
        else:
            capped_gvs.append((grad, var))
    train_step = optimizer.apply_gradients(capped_gvs, global_step=global_step)
    model['train_step'] = train_step

    ####################
    # Statistics
    ####################
    # Statistics (hard measures) are always computed with matching.
    y_out_hard = tf.to_float(y_out > 0.5)
    iou_hard = modellib.f_iou(y_out_hard, y_gt, timespan, pairwise=True)
    wt_cov_hard = modellib.f_weighted_coverage(iou_hard, y_gt)
    model['wt_cov_hard'] = wt_cov_hard
    unwt_cov_hard = modellib.f_unweighted_coverage(iou_hard, match_count)
    model['unwt_cov_hard'] = unwt_cov_hard
    iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1])
    iou_hard = tf.reduce_sum(
        tf.reduce_sum(iou_hard_mask, [1]) / match_count) / num_ex_f
    model['iou_hard'] = iou_hard

    dice = modellib.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
    dice = tf.reduce_sum(tf.reduce_sum(
        dice * real_match, reduction_indices=[1, 2]) / match_count) / \
        num_ex_f
    model['dice'] = dice
    model['count_acc'] = modellib.f_count_acc(s_out, s_gt)
    model['dic'] = modellib.f_dic(s_out, s_gt, abs=False)
    model['dic_abs'] = modellib.f_dic(s_out, s_gt, abs=True)

    ################################
    # Controller output statistics
    ################################
    if fixed_gamma:
        attn_lg_gamma_mean = tf.constant([0.0])
        attn_box_lg_gamma_mean = tf.constant([2.0])
        y_out_lg_gamma_mean = tf.constant([2.0])
    else:
        attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / num_ex_f / timespan
        attn_box_lg_gamma_mean = tf.reduce_sum(
            attn_box_lg_gamma) / num_ex_f / timespan
        y_out_lg_gamma_mean = tf.reduce_sum(
            y_out_lg_gamma) / num_ex_f / timespan
    model['attn_lg_gamma_mean'] = attn_lg_gamma_mean
    model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean
    model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean

    return model
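
A brief inference sketch for the attention model above. Assumptions: opt is
filled with the keys read at the top, sess holds restored weights, and dummy
groundtruth arrays (x_np, y_gt_dummy, s_gt_dummy are hypothetical inputs) are
fed because the transformation and knob nodes still reference those
placeholders at test time:

m = get_model(opt, is_training=False)
feed = {m['x']: x_np,            # [B, H, W, D] input batch
        m['y_gt']: y_gt_dummy,   # [B, T, H, W] zeros at test time
        m['s_gt']: s_gt_dummy,   # [B, T] zeros at test time
        m['phase_train']: False}
y_pred, s_pred = sess.run([m['y_out'], m['s_out']], feed)
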
Example #4
def get_model(opt):
  """The box model"""
  log = logger.get()
  model = {}

  timespan = opt['timespan']
  inp_height = opt['inp_height']
  inp_width = opt['inp_width']
  inp_depth = opt['inp_depth']
  padding = opt['padding']
  filter_height = opt['filter_height']
  filter_width = opt['filter_width']

  ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
  ctrl_cnn_depth = opt['ctrl_cnn_depth']
  ctrl_cnn_pool = opt['ctrl_cnn_pool']
  ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

  num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
  ctrl_mlp_dim = opt['ctrl_mlp_dim']

  attn_box_padding_ratio = opt['attn_box_padding_ratio']

  wd = opt['weight_decay']
  use_bn = opt['use_bn']
  box_loss_fn = opt['box_loss_fn']
  base_learn_rate = opt['base_learn_rate']
  learn_rate_decay = opt['learn_rate_decay']
  steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
  pretrain_cnn = opt['pretrain_cnn']

  if 'pretrain_net' in opt:
    pretrain_net = opt['pretrain_net']
  else:
    pretrain_net = None

  if 'freeze_pretrain_cnn' in opt:
    freeze_pretrain_cnn = opt['freeze_pretrain_cnn']
  else:
    freeze_pretrain_cnn = True

  squash_ctrl_params = opt['squash_ctrl_params']
  clip_gradient = opt['clip_gradient']
  fixed_order = opt['fixed_order']
  num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
  num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']

  if 'fixed_var' in opt:
    fixed_var = opt['fixed_var']
  else:
    fixed_var = True

  if 'use_iou_box' in opt:
    use_iou_box = opt['use_iou_box']
  else:
    use_iou_box = False

  if 'dynamic_var' in opt:
    dynamic_var = opt['dynamic_var']
  else:
    dynamic_var = False

  if 'num_semantic_classes' in opt:
    num_semantic_classes = opt['num_semantic_classes']
  else:
    num_semantic_classes = 1

  if 'add_d_out' in opt:
    add_d_out = opt['add_d_out']
    add_y_out = opt['add_y_out']
  else:
    add_d_out = False
    add_y_out = False

  rnd_hflip = opt['rnd_hflip']
  rnd_vflip = opt['rnd_vflip']
  rnd_transpose = opt['rnd_transpose']
  rnd_colour = opt['rnd_colour']

  ############################
  # Input definition
  ############################
  # Input image, [B, H, W, D]
  x = tf.placeholder(
      'float', [None, inp_height, inp_width, inp_depth], name='x')
  x_shape = tf.shape(x)
  num_ex = x_shape[0]

  # Groundtruth segmentation, [B, T, H, W]
  y_gt = tf.placeholder(
      'float', [None, timespan, inp_height, inp_width], name='y_gt')

  # Groundtruth confidence score, [B, T]
  s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

  if add_d_out:
    d_in = tf.placeholder(
        'float', [None, inp_height, inp_width, 8], name='d_in')
    model['d_in'] = d_in
  if add_y_out:
    y_in = tf.placeholder(
        'float', [None, inp_height, inp_width, num_semantic_classes],
        name='y_in')
    model['y_in'] = y_in
  # Whether in training stage.
  phase_train = tf.placeholder('bool', name='phase_train')
  phase_train_f = tf.to_float(phase_train)
  model['x'] = x
  model['y_gt'] = y_gt
  model['s_gt'] = s_gt
  model['phase_train'] = phase_train

  # Global step
  global_step = tf.Variable(0.0, name='global_step')

  ###############################
  # Random input transformation
  ###############################
  # Either add both or add nothing.
  assert (add_d_out and add_y_out) or (not add_d_out and not add_y_out)
  if not add_d_out:
    results = img.random_transformation(
        x,
        padding,
        phase_train,
        rnd_hflip=rnd_hflip,
        rnd_vflip=rnd_vflip,
        rnd_transpose=rnd_transpose,
        rnd_colour=rnd_colour,
        y=y_gt)
    x, y_gt = results['x'], results['y']
  else:
    results = img.random_transformation(
        x,
        padding,
        phase_train,
        rnd_hflip=rnd_hflip,
        rnd_vflip=rnd_vflip,
        rnd_transpose=rnd_transpose,
        rnd_colour=rnd_colour,
        y=y_gt,
        d=d_in,
        c=y_in)
    x, y_gt, d_in, y_in = results['x'], results['y'], results['d'], results['c']
    model['d_in_trans'] = d_in
    model['y_in_trans'] = y_in
  model['x_trans'] = x
  model['y_gt_trans'] = y_gt

  ############################
  # Canvas: external memory
  ############################
  canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
  ccnn_inp_depth = inp_depth + 1
  acnn_inp_depth = inp_depth + 1
  if add_d_out:
    ccnn_inp_depth += 8
    acnn_inp_depth += 8
  if add_y_out:
    ccnn_inp_depth += num_semantic_classes
    acnn_inp_depth += num_semantic_classes

  ############################
  # Controller CNN definition
  ############################
  ccnn_filters = ctrl_cnn_filter_size
  ccnn_nlayers = len(ccnn_filters)
  ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
  ccnn_pool = ctrl_cnn_pool
  ccnn_act = [tf.nn.relu] * ccnn_nlayers
  ccnn_use_bn = [use_bn] * ccnn_nlayers
  pt = pretrain_net or pretrain_cnn

  if pt:
    log.info('Loading pretrained weights from {}'.format(pt))
    with h5py.File(pt, 'r') as h5f:
      pt_cnn_nlayers = 0
      # Assuming pt_cnn_nlayers is smaller than or equal to
      # ccnn_nlayers.
      for ii in range(ccnn_nlayers):
        if 'attn_cnn_w_{}'.format(ii) in h5f:
          cnn_prefix = 'attn_'
          log.info('Loading attn_cnn_w_{}'.format(ii))
          log.info('Loading attn_cnn_b_{}'.format(ii))
          pt_cnn_nlayers += 1
        elif 'cnn_w_{}'.format(ii) in h5f:
          cnn_prefix = ''
          log.info('Loading cnn_w_{}'.format(ii))
          log.info('Loading cnn_b_{}'.format(ii))
          pt_cnn_nlayers += 1
        elif 'ctrl_cnn_w_{}'.format(ii) in h5f:
          cnn_prefix = 'ctrl_'
          log.info('Loading ctrl_cnn_w_{}'.format(ii))
          log.info('Loading ctrl_cnn_b_{}'.format(ii))
          pt_cnn_nlayers += 1

      ccnn_init_w = [{
          'w': h5f['{}cnn_w_{}'.format(cnn_prefix, ii)][:],
          'b': h5f['{}cnn_b_{}'.format(cnn_prefix, ii)][:]
      } for ii in range(pt_cnn_nlayers)]
      for ii in range(pt_cnn_nlayers):
        for tt in range(timespan):
          for w in ['beta', 'gamma']:
            ccnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[
                '{}cnn_{}_{}_{}'.format(cnn_prefix, ii, tt, w)][:]
      ccnn_frozen = [freeze_pretrain_cnn] * pt_cnn_nlayers
      for ii in range(pt_cnn_nlayers, ccnn_nlayers):
        ccnn_init_w.append(None)
        ccnn_frozen.append(False)
  else:
    ccnn_init_w = None
    ccnn_frozen = None

  ccnn = nn.cnn(ccnn_filters,
                ccnn_channels,
                ccnn_pool,
                ccnn_act,
                ccnn_use_bn,
                phase_train=phase_train,
                wd=wd,
                scope='ctrl_cnn',
                model=model,
                init_weights=ccnn_init_w,
                frozen=ccnn_frozen)
  h_ccnn = [None] * timespan

  ############################
  # Controller RNN definition
  ############################
  ccnn_subsample = np.array(ccnn_pool).prod()
  # Integer division: these feed tf.reshape shape arguments later.
  crnn_h = inp_height // ccnn_subsample
  crnn_w = inp_width // ccnn_subsample
  crnn_dim = ctrl_rnn_hid_dim
  canvas_dim = inp_height * inp_width // (ccnn_subsample**2)

  glimpse_map_dim = crnn_h * crnn_w
  glimpse_feat_dim = ccnn_channels[-1]
  crnn_inp_dim = glimpse_feat_dim

  pt = pretrain_net
  if pt:
    log.info('Loading pretrained controller RNN weights from {}'.format(pt))
    h5f = h5py.File(pt, 'r')
    crnn_init_w = {}
    for w in [
        'w_xi', 'w_hi', 'b_i', 'w_xf', 'w_hf', 'b_f', 'w_xu', 'w_hu', 'b_u',
        'w_xo', 'w_ho', 'b_o'
    ]:
      key = 'ctrl_lstm_{}'.format(w)
      crnn_init_w[w] = h5f[key][:]
    crnn_frozen = None
  else:
    crnn_init_w = None
    crnn_frozen = None

  crnn_state = [None] * (timespan + 1)
  crnn_glimpse_map = [None] * timespan
  crnn_g_i = [None] * timespan
  crnn_g_f = [None] * timespan
  crnn_g_o = [None] * timespan
  h_crnn = [None] * timespan
  crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
  crnn_cell = nn.lstm(
      crnn_inp_dim,
      crnn_dim,
      wd=wd,
      scope='ctrl_lstm',
      init_weights=crnn_init_w,
      frozen=crnn_frozen,
      model=model)

  ############################
  # Glimpse MLP definition
  ############################
  gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
  gmlp_act = [tf.nn.relu] * \
      (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
  gmlp_dropout = None

  pt = pretrain_net
  if pt:
    log.info('Loading pretrained glimpse MLP weights from {}'.format(pt))
    h5f = h5py.File(pt, 'r')
    gmlp_init_w = [{
        'w': h5f['glimpse_mlp_w_{}'.format(ii)][:],
        'b': h5f['glimpse_mlp_b_{}'.format(ii)][:]
    } for ii in range(num_glimpse_mlp_layers)]
    gmlp_frozen = None
  else:
    gmlp_init_w = None
    gmlp_frozen = None

  gmlp = nn.mlp(gmlp_dims,
                gmlp_act,
                add_bias=True,
                dropout_keep=gmlp_dropout,
                phase_train=phase_train,
                wd=wd,
                scope='glimpse_mlp',
                init_weights=gmlp_init_w,
                frozen=gmlp_frozen,
                model=model)

  ############################
  # Controller MLP definition
  ############################
  cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
      (num_ctrl_mlp_layers - 1) + [9]
  cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
  cmlp_dropout = None

  pt = pretrain_net
  if pt:
    log.info('Loading pretrained controller MLP weights from {}'.format(pt))
    h5f = h5py.File(pt, 'r')
    cmlp_init_w = [{
        'w': h5f['ctrl_mlp_w_{}'.format(ii)][:],
        'b': h5f['ctrl_mlp_b_{}'.format(ii)][:]
    } for ii in range(num_ctrl_mlp_layers)]
    cmlp_frozen = None
  else:
    cmlp_init_w = None
    cmlp_frozen = None

  cmlp = nn.mlp(cmlp_dims,
                cmlp_act,
                add_bias=True,
                dropout_keep=cmlp_dropout,
                phase_train=phase_train,
                wd=wd,
                scope='ctrl_mlp',
                init_weights=cmlp_init_w,
                frozen=cmlp_frozen,
                model=model)

  ##########################
  # Score MLP definition
  ##########################
  pt = pretrain_net
  if pt:
    log.info('Loading score mlp weights from {}'.format(pt))
    h5f = h5py.File(pt, 'r')
    smlp_init_w = [{
        'w': h5f['score_mlp_w_{}'.format(ii)][:],
        'b': h5f['score_mlp_b_{}'.format(ii)][:]
    } for ii in range(1)]
  else:
    smlp_init_w = None
  smlp = nn.mlp([crnn_dim, num_semantic_classes], [None],
                wd=wd,
                scope='score_mlp',
                init_weights=smlp_init_w,
                model=model)
  s_out = [None] * timespan

  ##########################
  # Attention box
  ##########################
  attn_ctr_norm = [None] * timespan
  attn_lg_size = [None] * timespan
  attn_lg_var = [None] * timespan
  attn_ctr = [None] * timespan
  attn_size = [None] * timespan
  attn_top_left = [None] * timespan
  attn_bot_right = [None] * timespan
  attn_box = [None] * timespan
  attn_box_lg_gamma = [None] * timespan
  attn_box_gamma = [None] * timespan
  const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
  attn_box_beta = tf.constant([-5.0])
  iou_soft_box = [None] * timespan

  #############################
  # Groundtruth attention box
  #############################
  attn_top_left_gt, attn_bot_right_gt, attn_box_gt = modellib.get_gt_box(
      y_gt, padding_ratio=attn_box_padding_ratio, center_shift_ratio=0.0)
  attn_ctr_gt, attn_size_gt = modellib.get_box_ctr_size(attn_top_left_gt,
                                                        attn_bot_right_gt)
  attn_ctr_norm_gt = modellib.get_normalized_center(attn_ctr_gt, inp_height,
                                                    inp_width)
  attn_lg_size_gt = modellib.get_normalized_size(attn_size_gt, inp_height,
                                                 inp_width)

  ##########################
  # Groundtruth mix
  ##########################
  grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

  ##########################
  # Computation graph
  ##########################
  for tt in range(timespan):
    # Controller CNN
    ccnn_inp_list = [x, canvas]
    if add_d_out:
      ccnn_inp_list.append(d_in)
    if add_y_out:
      ccnn_inp_list.append(y_in)
    ccnn_inp = tf.concat(3, ccnn_inp_list)
    acnn_inp = ccnn_inp
    h_ccnn[tt] = ccnn(ccnn_inp)
    _h_ccnn = h_ccnn[tt]
    h_ccnn_last = _h_ccnn[-1]

    # Controller RNN [B, R1]
    crnn_inp = tf.reshape(h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])
    crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
    crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
    crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
    crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
    h_crnn[tt] = [None] * num_ctrl_rnn_iter

    crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))

    crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
    crnn_glimpse_map[tt][0] = tf.ones(
        tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim
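    # The glimpse map starts uniform over the H' x W' feature grid and is
    # re-estimated by the glimpse MLP after each inner RNN iteration below.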

    # Inner glimpse RNN
    for tt2 in range(num_ctrl_rnn_iter):
      crnn_glimpse = tf.reduce_sum(crnn_inp * crnn_glimpse_map[tt][tt2], [1])
      crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
          crnn_g_o[tt][tt2] = \
          crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
      h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2], [0, crnn_dim],
                                 [-1, crnn_dim])
      h_gmlp = gmlp(h_crnn[tt][tt2])
      if tt2 < num_ctrl_rnn_iter - 1:
        crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(h_gmlp[-1], 2)

    ctrl_out = cmlp(h_crnn[tt][-1])[-1]

    attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
    attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

    # Restrict to (-1, 1), (-inf, 0)
    if squash_ctrl_params:
      attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
      attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

    attn_ctr[tt], attn_size[tt] = modellib.get_unnormalized_attn(
        attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
    attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])

    if fixed_var:
      attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
    else:
      attn_lg_var[tt] = modellib.get_normalized_var(attn_size[tt],
                                                    filter_height, filter_width)
    if dynamic_var:
      attn_lg_var[tt] = tf.slice(ctrl_out, [0, 4], [-1, 2])
    attn_box_gamma[tt] = tf.reshape(
        tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
    attn_top_left[tt], attn_bot_right[tt] = modellib.get_box_coord(
        attn_ctr[tt], attn_size[tt])

    # Initial filters (predicted)
    filter_y = modellib.get_gaussian_filter(
        attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 0],
        inp_height, filter_height)
    filter_x = modellib.get_gaussian_filter(
        attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1],
        inp_width, filter_width)
    filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
    filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

    # Attention box
    attn_box[tt] = attn_box_gamma[tt] * modellib.extract_patch(
        const_ones, filter_y_inv, filter_x_inv, 1)
    attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
    attn_box[tt] = tf.reshape(attn_box[tt], [-1, 1, inp_height, inp_width])
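    # attn_box renders the predicted box as a soft [B, 1, H, W] mask: the
    # transposed Gaussian filters paint a constant patch onto the image grid,
    # and the sigmoid with beta = -5 suppresses everything outside the box.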

    if fixed_order:
      _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
    else:
      if use_iou_box:
        iou_soft_box[tt] = modellib.f_iou_box(
            tf.expand_dims(attn_top_left[tt], 1),
            tf.expand_dims(attn_bot_right[tt], 1), attn_top_left_gt,
            attn_bot_right_gt)
      else:
        iou_soft_box[tt] = modellib.f_inter(
            attn_box[tt], attn_box_gt) / \
            modellib.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
      grd_match = modellib.f_greedy_match(iou_soft_box[tt], grd_match_cum)
      grd_match = tf.expand_dims(tf.expand_dims(grd_match, 2), 3)
      _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1), 3)

    # Add independent uniform noise to groundtruth.
    _noise = tf.random_uniform(
        tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3)
    _y_out = _y_out - _y_out * _noise
    canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))
    # canvas += tf.stop_gradient(_y_out)
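    # The canvas keeps a pixelwise max of every mask drawn so far, and the
    # stop_gradient prevents backpropagation through this external memory.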

    # Scoring network
    s_out[tt] = smlp(h_crnn[tt][-1])[-1]

    if num_semantic_classes == 1:
      s_out[tt] = tf.sigmoid(s_out[tt])
    else:
      s_out[tt] = tf.nn.softmax(s_out[tt])

  #########################
  # Model outputs
  #########################
  s_out = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in s_out])
  if num_semantic_classes == 1:
    s_out = s_out[:, :, 0]
  model['s_out'] = s_out
  attn_box = tf.concat(1, attn_box)
  model['attn_box'] = attn_box
  attn_top_left = tf.concat(1,
                            [tf.expand_dims(tmp, 1) for tmp in attn_top_left])
  attn_bot_right = tf.concat(1,
                             [tf.expand_dims(tmp, 1) for tmp in attn_bot_right])
  attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
  attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
  model['attn_top_left'] = attn_top_left
  model['attn_bot_right'] = attn_bot_right
  model['attn_ctr'] = attn_ctr
  model['attn_size'] = attn_size
  model['attn_ctr_norm_gt'] = attn_ctr_norm_gt
  model['attn_lg_size_gt'] = attn_lg_size_gt
  model['attn_top_left_gt'] = attn_top_left_gt
  model['attn_bot_right_gt'] = attn_bot_right_gt
  model['attn_box_gt'] = attn_box_gt
  attn_ctr_norm = tf.concat(1,
                            [tf.expand_dims(tmp, 1) for tmp in attn_ctr_norm])
  attn_lg_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_size])
  model['attn_ctr_norm'] = attn_ctr_norm
  model['attn_lg_size'] = attn_lg_size

  attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
  attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

  #########################
  # Loss function
  #########################
  y_gt_shape = tf.shape(y_gt)
  num_ex_f = tf.to_float(y_gt_shape[0])
  max_num_obj = tf.to_float(y_gt_shape[1])

  ############################
  # Box loss
  ############################
  if fixed_order:
    # [B, T] for fixed order.
    iou_soft_box = modellib.f_iou(attn_box, attn_box_gt, pairwise=False)
  else:
    # [B, T, T] for matching.
    iou_soft_box = tf.concat(
        1, [tf.expand_dims(iou_soft_box[tt], 1) for tt in range(timespan)])

  identity_match = modellib.get_identity_match(num_ex, timespan, s_gt)
  if fixed_order:
    match_box = identity_match
  else:
    match_box = modellib.f_segm_match(iou_soft_box, s_gt)

  model['match_box'] = match_box
  match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
  match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
  match_count_box = tf.maximum(1.0, match_count_box)

  # [B] if fixed order, [B, T] if matching.
  if fixed_order:
    iou_soft_box_mask = iou_soft_box
  else:
    iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
  iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
  iou_soft_box = tf.reduce_sum(iou_soft_box / match_count_box) / num_ex_f

  if box_loss_fn == 'mse':
    box_loss = modellib.f_match_loss(
        attn_params,
        attn_params_gt,
        match_box,
        timespan,
        modellib.f_squared_err,
        model=model)
  elif box_loss_fn == 'huber':
    box_loss = modellib.f_match_loss(attn_params, attn_params_gt, match_box,
                                     timespan, modellib.f_huber)
  elif box_loss_fn == 'iou':
    box_loss = -iou_soft_box
  elif box_loss_fn == 'wt_iou':
    # NOTE: wt_iou_soft_box (and box_map/box_map_gt below) are not defined in
    # this function; these branches assume they are provided elsewhere.
    box_loss = -wt_iou_soft_box
  elif box_loss_fn == 'wt_cov':
    box_loss = -modellib.f_weighted_coverage(iou_soft_box, box_map_gt)
  elif box_loss_fn == 'bce':
    box_loss = modellib.f_match_loss(box_map, box_map_gt, match_box, timespan,
                                     modellib.f_bce)
  else:
    raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
  model['box_loss'] = box_loss

  box_loss_coeff = tf.constant(1.0)
  model['box_loss_coeff'] = box_loss_coeff
  tf.add_to_collection('losses', box_loss_coeff * box_loss)

  ####################
  # Score loss
  ####################
  if num_semantic_classes == 1:
    conf_loss = modellib.f_conf_loss(
        s_out, match_box, timespan, use_cum_min=True)
  else:
    conf_loss = modellib.f_conf_loss(
        1 - s_out[:, :, 0], match_box, timespan, use_cum_min=True)
  model['conf_loss'] = conf_loss
  conf_loss_coeff = tf.constant(1.0)
  tf.add_to_collection('losses', conf_loss_coeff * conf_loss)

  ####################
  # Total loss
  ####################
  total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
  model['loss'] = total_loss

  ####################
  # Optimizer
  ####################
  learn_rate = tf.train.exponential_decay(
      base_learn_rate,
      global_step,
      steps_per_learn_rate_decay,
      learn_rate_decay,
      staircase=True)
  model['learn_rate'] = learn_rate
  eps = 1e-7
  optim = tf.train.AdamOptimizer(learn_rate, epsilon=eps)
  gvs = optim.compute_gradients(total_loss)
  capped_gvs = []
  for grad, var in gvs:
    if grad is not None:
      capped_gvs.append((tf.clip_by_value(grad, -1, 1), var))
    else:
      capped_gvs.append((grad, var))
  train_step = optim.apply_gradients(capped_gvs, global_step=global_step)
  model['train_step'] = train_step

  ####################
  # Glimpse
  ####################
  # T * T2 * [B, H' * W'] => [B, T, T2, H', W']
  crnn_glimpse_map = tf.concat(1, [
      tf.expand_dims(
          tf.concat(1, [
              tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
              for tt2 in range(num_ctrl_rnn_iter)
          ]), 1) for tt in range(timespan)
  ])
  crnn_glimpse_map = tf.reshape(
      crnn_glimpse_map, [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w])
  model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

  return model
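For reference, a sketch of the option dictionary that this get_model reads, with illustrative placeholder values rather than tuned settings; keys probed via 'in opt' (pretrain_net, freeze_pretrain_cnn, fixed_var, use_iou_box, dynamic_var, num_semantic_classes, add_d_out/add_y_out) can be omitted to fall back to the defaults shown above.

# Illustrative option dict for get_model above; values are placeholders,
# not tuned settings.
opt = {
    'timespan': 5,
    'inp_height': 128, 'inp_width': 128, 'inp_depth': 3,
    'padding': 16,
    'filter_height': 48, 'filter_width': 48,
    'ctrl_cnn_filter_size': [3, 3, 3],
    'ctrl_cnn_depth': [8, 16, 32],
    'ctrl_cnn_pool': [2, 2, 2],
    'ctrl_rnn_hid_dim': 256,
    'num_ctrl_mlp_layers': 2, 'ctrl_mlp_dim': 256,
    'attn_box_padding_ratio': 0.2,
    'weight_decay': 5e-5, 'use_bn': True,
    'box_loss_fn': 'iou',
    'base_learn_rate': 1e-3,
    'learn_rate_decay': 0.96, 'steps_per_learn_rate_decay': 5000,
    'pretrain_cnn': None,
    'squash_ctrl_params': True,
    'clip_gradient': 1.0,
    'fixed_order': False,
    'num_ctrl_rnn_iter': 5, 'num_glimpse_mlp_layers': 2,
    'rnd_hflip': True, 'rnd_vflip': False,
    'rnd_transpose': False, 'rnd_colour': False,
}
model = get_model(opt)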
Example #5
def get_model(opt, device='/cpu:0'):
    """The attention model"""
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']

    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    segm_loss_fn = opt['segm_loss_fn']
    box_loss_fn = opt['box_loss_fn']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    use_knob = opt['use_knob']
    knob_base = opt['knob_base']
    knob_decay = opt['knob_decay']
    steps_per_knob_decay = opt['steps_per_knob_decay']
    knob_box_offset = opt['knob_box_offset']
    knob_segm_offset = opt['knob_segm_offset']
    knob_use_timescale = opt['knob_use_timescale']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']

    squash_ctrl_params = opt['squash_ctrl_params']
    fixed_order = opt['fixed_order']
    clip_gradient = opt['clip_gradient']
    fixed_gamma = opt['fixed_gamma']
    ctrl_rnn_inp_struct = opt['ctrl_rnn_inp_struct']  # dense or attn
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']
    pretrain_ctrl_net = opt['pretrain_ctrl_net']
    pretrain_attn_net = opt['pretrain_attn_net']
    pretrain_net = opt['pretrain_net']
    
    # freeze_ctrl_cnn = True
    # freeze_ctrl_rnn = True
    # freeze_attn_net = True
    freeze_ctrl_cnn = opt['freeze_ctrl_cnn']
    freeze_ctrl_rnn = opt['freeze_ctrl_rnn']
    freeze_attn_net = opt['freeze_attn_net']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

############################
# Input definition
############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth],
                           name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width],
                              name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        phase_train_f = tf.to_float(phase_train)
        
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0, name='global_step')
        # global_step = tf.Variable(0.0)

###############################
# Random input transformation
###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

############################
# Canvas: external memory
############################
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        ccnn_inp_depth = inp_depth + 1
        acnn_inp_depth = inp_depth + 1

############################
# Controller CNN definition
############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        acnn_nlayers = len(attn_cnn_filter_size)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers

        pt = pretrain_net or pretrain_ctrl_net
        if pt:
            log.info('Loading pretrained controller CNN weights from {}'.format(
                pt))
            h5f = h5py.File(pt, 'r')
            ccnn_init_w = [{'w': h5f['ctrl_cnn_w_{}'.format(ii)][:],
                            'b': h5f['ctrl_cnn_b_{}'.format(ii)][:]}
                           for ii in xrange(ccnn_nlayers)]
            for ii in xrange(ccnn_nlayers):
                for tt in xrange(timespan):
                    for w in ['beta', 'gamma']:
                        ccnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[
                            'ctrl_cnn_{}_{}_{}'.format(ii, tt, w)][:]
            ccnn_frozen = [freeze_ctrl_cnn] * ccnn_nlayers
        else:
            ccnn_init_w = None
            ccnn_frozen = [freeze_ctrl_cnn] * ccnn_nlayers

        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model, init_weights=ccnn_init_w,
                      frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

############################
# Controller RNN definition
############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        canvas_dim = inp_height * inp_width / (ccnn_subsample ** 2)

        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        if ctrl_rnn_inp_struct == 'dense':
            crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        elif ctrl_rnn_inp_struct == 'attn':
            crnn_inp_dim = glimpse_feat_dim

        pt = pretrain_net or pretrain_ctrl_net
        if pt:
            log.info('Loading pretrained controller RNN weights from {}'.format(
                pt))
            h5f = h5py.File(pt, 'r')
            crnn_init_w = {}
            for w in ['w_xi', 'w_hi', 'b_i', 'w_xf', 'w_hf', 'b_f', 'w_xu',
                      'w_hu', 'b_u', 'w_xo', 'w_ho', 'b_o']:
                key = 'ctrl_lstm_{}'.format(w)
                crnn_init_w[w] = h5f[key][:]
            crnn_frozen = freeze_ctrl_rnn
        else:
            crnn_init_w = None
            crnn_frozen = freeze_ctrl_rnn

        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm',
                            init_weights=crnn_init_w, frozen=crnn_frozen,
                            model=model)

############################
# Glimpse MLP definition
############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None

        pt = pretrain_net or pretrain_ctrl_net
        if pt:
            log.info('Loading pretrained glimpse MLP weights from {}'.format(
                pt))
            h5f = h5py.File(pt, 'r')
            gmlp_init_w = [{'w': h5f['glimpse_mlp_w_{}'.format(ii)][:],
                            'b': h5f['glimpse_mlp_b_{}'.format(ii)][:]}
                           for ii in xrange(num_glimpse_mlp_layers)]
            gmlp_frozen = [freeze_ctrl_rnn] * num_glimpse_mlp_layers
        else:
            gmlp_init_w = None
            gmlp_frozen = [freeze_ctrl_rnn] * num_glimpse_mlp_layers

        gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True,
                      dropout_keep=gmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='glimpse_mlp',
                      init_weights=gmlp_init_w, frozen=gmlp_frozen,
                      model=model)

############################
# Controller MLP definition
############################
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None

        pt = pretrain_net or pretrain_ctrl_net
        if pt:
            log.info('Loading pretrained controller MLP weights from {}'.format(
                pt))
            h5f = h5py.File(pt, 'r')
            cmlp_init_w = [{'w': h5f['ctrl_mlp_w_{}'.format(ii)][:],
                            'b': h5f['ctrl_mlp_b_{}'.format(ii)][:]}
                           for ii in xrange(num_ctrl_mlp_layers)]
            cmlp_frozen = [freeze_ctrl_rnn] * num_ctrl_mlp_layers
        else:
            cmlp_init_w = None
            cmlp_frozen = [freeze_ctrl_rnn] * num_ctrl_mlp_layers

        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='ctrl_mlp',
                      init_weights=cmlp_init_w, frozen=cmlp_frozen,
                      model=model)

###########################
# Attention CNN definition
###########################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = attn_cnn_pool
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers

        pt = pretrain_net or pretrain_attn_net
        if pt:
            log.info('Loading pretrained attention CNN weights from {}'.format(
                pt))
            h5f = h5py.File(pt, 'r')
            acnn_init_w = [{'w': h5f['attn_cnn_w_{}'.format(ii)][:],
                            'b': h5f['attn_cnn_b_{}'.format(ii)][:]}
                           for ii in xrange(acnn_nlayers)]
            for ii in xrange(acnn_nlayers):
                for tt in xrange(timespan):
                    for w in ['beta', 'gamma']:
                        key = 'attn_cnn_{}_{}_{}'.format(ii, tt, w)
                        acnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[key][:]
            acnn_frozen = [freeze_attn_net] * acnn_nlayers
        else:
            acnn_init_w = None
            acnn_frozen = [freeze_attn_net] * acnn_nlayers

        acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act,
                      acnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='attn_cnn', model=model,
                      init_weights=acnn_init_w,
                      frozen=acnn_frozen)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

############################
# Attention MLP definition
############################
        acnn_subsample = np.array(acnn_pool).prod()
        acnn_h = filter_height / acnn_subsample
        acnn_w = filter_width / acnn_subsample
        amlp_inp_dim = acnn_h * acnn_w * acnn_channels[-1]
        core_depth = attn_mlp_depth
        core_dim = acnn_h * acnn_w * core_depth
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None

        pt = pretrain_net or pretrain_attn_net
        if pt:
            log.info('Loading pretrained attention MLP weights from {}'.format(
                pt))
            h5f = h5py.File(pt, 'r')
            amlp_init_w = [{'w': h5f['attn_mlp_w_{}'.format(ii)][:],
                            'b': h5f['attn_mlp_b_{}'.format(ii)][:]}
                           for ii in xrange(num_attn_mlp_layers)]
            amlp_frozen = [freeze_attn_net] * num_attn_mlp_layers
        else:
            amlp_init_w = None
            amlp_frozen = [freeze_attn_net] * num_attn_mlp_layers

        amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout,
                      phase_train=phase_train, wd=wd, scope='attn_mlp',
                      init_weights=amlp_init_w,
                      frozen=amlp_frozen,
                      model=model)

##########################
# Score MLP definition
##########################
        pt = pretrain_net
        if pt:
            log.info('Loading score mlp weights from {}'.format(pt))
            h5f = h5py.File(pt, 'r')
            smlp_init_w = [{'w': h5f['score_mlp_w_{}'.format(ii)][:],
                            'b': h5f['score_mlp_b_{}'.format(ii)][:]}
                           for ii in xrange(1)]
        else:
            smlp_init_w = None
        smlp = nn.mlp([crnn_dim + core_dim, 1], [tf.sigmoid], wd=wd,
                      scope='score_mlp', init_weights=smlp_init_w,
                      model=model)
        s_out = [None] * timespan

#############################
# Attention DCNN definition
#############################
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = attn_dcnn_pool
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth

        adcnn_bn_nlayers = adcnn_nlayers
        # adcnn_bn_nlayers = adcnn_nlayers - 1
        adcnn_use_bn = [use_bn] * adcnn_bn_nlayers + \
            [False] * (adcnn_nlayers - adcnn_bn_nlayers)
        adcnn_skip_ch = [0] + acnn_channels[:: -1][1:]
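        # Skip connections mirror the attention CNN encoder: the channel
        # counts are the encoder's in reverse, with no skip into layer 0.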

        pt = pretrain_net or pretrain_attn_net
        if pt:
            log.info('Loading pretrained attention DCNN weights from {}'.format(
                pt))
            h5f = h5py.File(pt, 'r')
            adcnn_init_w = [{'w': h5f['attn_dcnn_w_{}'.format(ii)][:],
                             'b': h5f['attn_dcnn_b_{}'.format(ii)][:]}
                            for ii in xrange(adcnn_nlayers)]
            for ii in xrange(adcnn_bn_nlayers):
                for tt in xrange(timespan):
                    for w in ['beta', 'gamma']:
                        key = 'attn_dcnn_{}_{}_{}'.format(ii, tt, w)
                        adcnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[key][:]
            
            adcnn_frozen = [freeze_attn_net] * adcnn_nlayers
        else:
            adcnn_init_w = None
            adcnn_frozen = [freeze_attn_net] * adcnn_nlayers

        adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool,
                        adcnn_act, use_bn=adcnn_use_bn, skip_ch=adcnn_skip_ch,
                        phase_train=phase_train, wd=wd, model=model,
                        init_weights=adcnn_init_w,
                        frozen=adcnn_frozen,
                        scope='attn_dcnn')
        h_adcnn = [None] * timespan

##########################
# Attention box
##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_lg_gamma = [None] * timespan
        attn_gamma = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        iou_soft_box = [None] * timespan
        const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        attn_box_gamma = [None] * timespan

#############################
# Groundtruth attention box
#############################
        # [B, T, 2]
        attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \
            attn_top_left_gt, attn_bot_right_gt = \
            base.get_gt_attn(y_gt,
                             padding_ratio=attn_box_padding_ratio,
                             center_shift_ratio=0.0,
                             min_padding=padding + 4)
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise),
                             min_padding=padding + 4)
        attn_ctr_norm_gt = base.get_normalized_center(
            attn_ctr_gt, inp_height, inp_width)
        attn_lg_size_gt = base.get_normalized_size(
            attn_size_gt, inp_height, inp_width)

##########################
# Groundtruth mix
##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

        # Scale mix ratio on different timesteps.
        if knob_use_timescale:
            gt_knob_time_scale = tf.reshape(
                1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0),
                [1, timespan, 1])
        else:
            gt_knob_time_scale = tf.ones([1, timespan, 1])

        # Mix in groundtruth box.
        global_step_box = tf.maximum(0.0, global_step - knob_box_offset)
        gt_knob_prob_box = tf.train.exponential_decay(
            knob_base, global_step_box, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_box = tf.minimum(
            1.0, gt_knob_prob_box * gt_knob_time_scale)
        gt_knob_box = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box)
        model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0]

        # Mix in groundtruth segmentation.
        global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset)
        gt_knob_prob_segm = tf.train.exponential_decay(
            knob_base, global_step_segm, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_segm = tf.minimum(
            1.0, gt_knob_prob_segm * gt_knob_time_scale)
        gt_knob_segm = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm)
        model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0]
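        # Both knobs anneal the probability of feeding groundtruth (rather
        # than the model's own predictions) back into the loop, in the
        # spirit of scheduled sampling.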

##########################
# Segmentation output
##########################
        y_out = [None] * timespan
        y_out_lg_gamma = [None] * timespan
        y_out_beta = tf.constant([-5.0])

##########################
# Computation graph
##########################
        for tt in xrange(timespan):
            # Controller CNN
            ccnn_inp = tf.concat(3, [x, canvas])
            acnn_inp = ccnn_inp
            h_ccnn[tt] = ccnn(ccnn_inp)
            _h_ccnn = h_ccnn[tt]
            h_ccnn_last = _h_ccnn[-1]

            # Controller RNN [B, R1]
            if ctrl_rnn_inp_struct == 'dense':
                crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
                crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[
                    tt] = crnn_cell(crnn_inp, crnn_state[tt - 1])
                h_crnn[tt] = tf.slice(
                    crnn_state[tt], [0, crnn_dim], [-1, crnn_dim])

                ctrl_out = cmlp(h_crnn[tt])[-1]

            elif ctrl_rnn_inp_struct == 'attn':
                crnn_inp = tf.reshape(
                    h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])
                crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
                crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
                h_crnn[tt] = [None] * num_ctrl_rnn_iter

                crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))

                crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
                crnn_glimpse_map[tt][0] = tf.ones(
                    tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

                # Inner glimpse RNN
                for tt2 in xrange(num_ctrl_rnn_iter):
                    crnn_glimpse = tf.reduce_sum(
                        crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                    crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
                        crnn_g_o[tt][tt2] = crnn_cell(
                            crnn_glimpse, crnn_state[tt][tt2 - 1])
                    h_crnn[tt][tt2] = tf.slice(
                        crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim])
                    h_gmlp = gmlp(h_crnn[tt][tt2])
                    if tt2 < num_ctrl_rnn_iter - 1:
                        crnn_glimpse_map[tt][
                            tt2 + 1] = tf.expand_dims(h_gmlp[-1], 2)

                ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)

            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))

            if fixed_gamma:
                attn_lg_gamma[tt] = tf.constant([0.0])
                # attn_box_lg_gamma[tt] = tf.constant([2.0])
                y_out_lg_gamma[tt] = tf.constant([2.0])
            else:
                attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1])
                # attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
                y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1])

            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            attn_gamma[tt] = tf.reshape(
                tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1])
            attn_box_gamma[tt] = tf.reshape(tf.exp(
                attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
            y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1])

            # Initial filters (predicted)
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            attn_box[tt] = base.extract_patch(
                const_ones * attn_box_gamma[tt],
                filter_y_inv, filter_x_inv, 1)
            attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
            attn_box[tt] = tf.reshape(attn_box[tt],
                                      [-1, 1, inp_height, inp_width])

            # Kick in GT bbox.
            if use_knob:
                if fixed_order:
                    attn_ctr_gtm = attn_ctr_gt_noise[:, tt, :]
                    attn_size_gtm = attn_size_gt_noise[:, tt, :]
                else:
                    iou_soft_box[tt] = base.f_inter(
                        attn_box[tt], attn_box_gt) / \
                        base.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
                    grd_match = base.f_greedy_match(
                        iou_soft_box[tt], grd_match_cum)

                    # [B, T, 1]
                    grd_match = tf.expand_dims(grd_match, 2)
                    attn_ctr_gtm = tf.reduce_sum(
                        grd_match * attn_ctr_gt_noise, 1)
                    attn_size_gtm = tf.reduce_sum(
                        grd_match * attn_size_gt_noise, 1)

                attn_ctr[tt] = phase_train_f * gt_knob_box[:, tt, 0: 1] * \
                    attn_ctr_gtm + \
                    (1 - phase_train_f * gt_knob_box[:, tt, 0: 1]) * \
                    attn_ctr[tt]
                attn_size[tt] = phase_train_f * gt_knob_box[:, tt, 0: 1] * \
                    attn_size_gtm + \
                    (1 - phase_train_f * gt_knob_box[:, tt, 0: 1]) * \
                    attn_size[tt]

            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])
            # stop_gradient returns a new tensor; assign the result so the
            # patch extraction below does not backprop into the filters.
            filter_y = tf.stop_gradient(filter_y)
            filter_x = tf.stop_gradient(filter_x)
            # Attended patch [B, A, A, D]
            x_patch[tt] = attn_gamma[tt] * base.extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]

            # Dense segmentation network [B, R] => [B, M]
            amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core_img = tf.reshape(
                h_core, [-1, acnn_h, acnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core_img, skip=skip)

            # Output
            y_out[tt] = base.extract_patch(
                h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width])

            # Scoring network
            if ctrl_rnn_inp_struct == 'dense':
                smlp_inp = tf.concat(1, [h_crnn[tt], h_core])
            elif ctrl_rnn_inp_struct == 'attn':
                smlp_inp = tf.concat(1, [h_crnn[tt][-1], h_core])
            s_out[tt] = smlp(smlp_inp)[-1]

            # Here the knob kicks in groundtruth segmentation at this timestep.
            # [B, N, 1, 1]
            if use_knob:
                _gt_knob_segm = tf.expand_dims(
                    tf.expand_dims(gt_knob_segm[:, tt, 0: 1], 2), 3)

                if fixed_order:
                    _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
                else:
                    grd_match = tf.expand_dims(grd_match, 3)
                    _y_out = tf.expand_dims(tf.reduce_sum(
                        grd_match * y_gt, 1), 3)
                # Add independent uniform noise to groundtruth.
                _noise = tf.random_uniform(
                    tf.pack([num_ex, inp_height, inp_width, 1]), 0,
                    gt_segm_noise)
                _y_out = _y_out - _y_out * _noise
                _y_out = phase_train_f * _gt_knob_segm * _y_out + \
                    (1 - phase_train_f * _gt_knob_segm) * \
                    tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1])
            else:
                _y_out = tf.reshape(y_out[tt],
                                    [-1, inp_height, inp_width, 1])
            canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))

#########################
# Model outputs
#########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        x_patch = tf.concat(1, [tf.expand_dims(x_patch[tt], 1)
                                for tt in xrange(timespan)])
        model['x_patch'] = x_patch

        attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_top_left])
        attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1)
                                 for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                  for tmp in attn_size])
        attn_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_lg_gamma])
        attn_box_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                          for tmp in attn_box_lg_gamma])
        y_out_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in y_out_lg_gamma])
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_box_gt'] = attn_box_gt
        attn_ctr_norm = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_ctr_norm])
        attn_lg_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                     for tmp in attn_lg_size])
        model['attn_ctr_norm'] = attn_ctr_norm
        model['attn_lg_size'] = attn_lg_size
        attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
        attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

#########################
# Loss function
#########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

############################
# Box loss
############################
        if fixed_order:
            # [B, T] for fixed order.
            iou_soft_box = base.f_iou(attn_box, attn_box_gt, pairwise=False)
        else:
            if use_knob:
                # [B, T, T] for matching.
                iou_soft_box = tf.concat(
                    1, [tf.expand_dims(iou_soft_box[tt], 1)
                        for tt in xrange(timespan)])
            else:
                iou_soft_box = base.f_iou(attn_box, attn_box_gt,
                                          timespan, pairwise=True)

        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        if fixed_order:
            match_box = identity_match
        else:
            match_box = base.f_segm_match(iou_soft_box, s_gt)

        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
        match_count_box = tf.maximum(1.0, match_count_box)

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_box_mask = iou_soft_box
        else:
            iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box / match_count_box) / num_ex_f

        if box_loss_fn == 'mse':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan,
                base.f_squared_err, model=model)
        elif box_loss_fn == 'huber':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan, base.f_huber)
        elif box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -base.f_weighted_coverage(iou_soft_box, attn_box_gt)
        elif box_loss_fn == 'bce':
            box_loss = base.f_match_loss(
                y_out, y_gt, match_box, timespan, base.f_bce)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        model['box_loss_coeff'] = box_loss_coeff
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

##############################
# Segmentation loss
##############################
        # IoU (soft)
        iou_soft_pairwise = base.f_iou(y_out, y_gt, timespan, pairwise=True)
        real_match = base.f_segm_match(iou_soft_pairwise, s_gt)
        if fixed_order:
            iou_soft = base.f_iou(y_out, y_gt, pairwise=False)
            match = identity_match
        else:
            iou_soft = iou_soft_pairwise
            match = real_match
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = base.f_weighted_coverage(iou_soft_pairwise, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = base.f_unweighted_coverage(
            iou_soft_pairwise, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_mask = iou_soft
        else:
            iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(iou_soft_mask, [1])
        iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f
        model['iou_soft'] = iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

####################
# Score loss
####################
        conf_loss = base.f_conf_loss(s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        tf.add_to_collection('losses', loss_mix_ratio * conf_loss)

####################
# Total loss
####################
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

####################
# Optimizer
####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

####################
# Statistics
####################
        # Statistics (hard measures) are always computed with matching.
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = base.f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = base.f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = base.f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_hard'] = iou_hard

        dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(
            dice * real_match, reduction_indices=[1, 2]) / match_count) / \
            num_ex_f
        model['dice'] = dice

        model['count_acc'] = base.f_count_acc(s_out, s_gt)
        model['dic'] = base.f_dic(s_out, s_gt, abs=False)
        model['dic_abs'] = base.f_dic(s_out, s_gt, abs=True)

################################
# Controller output statistics
################################
        if fixed_gamma:
            attn_lg_gamma_mean = tf.constant([0.0])
            attn_box_lg_gamma_mean = tf.constant([2.0])
            y_out_lg_gamma_mean = tf.constant([2.0])
        else:
            attn_lg_gamma_mean = tf.reduce_sum(
                attn_lg_gamma) / num_ex_f / timespan
            attn_box_lg_gamma_mean = tf.reduce_sum(
                attn_box_lg_gamma) / num_ex_f / timespan
            y_out_lg_gamma_mean = tf.reduce_sum(
                y_out_lg_gamma) / num_ex_f / timespan
        model['attn_lg_gamma_mean'] = attn_lg_gamma_mean
        model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean
        model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean

####################
# Debug gradients
####################
        ctrl_mlp_b_grad = tf.gradients(total_loss, model['ctrl_mlp_b_0'])
        model['ctrl_mlp_b_grad'] = ctrl_mlp_b_grad[0]

####################
# Glimpse
####################
        # T * T2 * [H', W'] => [T, T2, H', W']
        if ctrl_rnn_inp_struct == 'attn':
            crnn_glimpse_map = tf.concat(
                1, [tf.expand_dims(tf.concat(
                    1, [tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
                        for tt2 in xrange(num_ctrl_rnn_iter)]), 1)
                    for tt in xrange(timespan)])
            crnn_glimpse_map = tf.reshape(
                crnn_glimpse_map, [-1, timespan, num_ctrl_rnn_iter, crnn_h,
                                   crnn_w])
            model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

    return model
Example #6
    def build_inference_network(self, x):
        """Build inference part of the network."""
        config = self.config
        is_training = self.is_training

        # Activation functions (combining normalization).
        if config.norm_field == "batch":
            log.info("Using batch normalization")
            log.info("Setting sigma={:.3e}".format(config.sigma_init))
            log.info("Setting sigma learnable={}".format(config.learn_sigma))
            log.info("Setting L1={:.3e}".format(config.l1_reg))
            conv_act_fn = [
                get_bn_act(act=get_tf_fn(aa),
                           is_training=is_training,
                           sigma_init=config.sigma_init,
                           affine=config.norm_affine,
                           l1_reg=config.l1_reg,
                           mask=config.bn_mask[ii],
                           l1_collection=self.l1_collection,
                           learn_sigma=config.learn_sigma,
                           dtype=self.dtype())
                for ii, aa in enumerate(config.conv_act_fn)
            ]
        elif config.norm_field == "batch_ms":
            log.info("Using mean subtracted batch normalization")
            log.info("Setting L1={:.3e}".format(config.l1_reg))
            conv_act_fn = [
                get_bnms_act(act=get_tf_fn(aa),
                             is_training=is_training,
                             affine=config.norm_affine,
                             l1_reg=config.l1_reg,
                             mask=config.bn_mask[ii],
                             l1_collection=self.l1_collection,
                             dtype=self.dtype())
                for ii, aa in enumerate(config.conv_act_fn)
            ]
        elif config.norm_field == "layer":
            log.info("Using layer normalization")
            log.info("Setting sigma={:.3e}".format(config.sigma_init))
            log.info("Setting sigma learnable={}".format(config.learn_sigma))
            log.info("Setting L1={:.3e}".format(config.l1_reg))
            conv_act_fn = [
                get_ln_act(act=get_tf_fn(aa),
                           sigma_init=config.sigma_init,
                           affine=config.norm_affine,
                           l1_reg=config.l1_reg,
                           l1_collection=self.l1_collection,
                           learn_sigma=config.learn_sigma,
                           dtype=self.dtype())
                for ii, aa in enumerate(config.conv_act_fn)
            ]
        elif config.norm_field == "layer_ms":
            log.info("Using mean subtracted layer normalization")
            log.info("Setting L1={:.3e}".format(config.l1_reg))
            conv_act_fn = [
                get_lnms_act(act=get_tf_fn(aa),
                             affine=config.norm_affine,
                             l1_reg=config.l1_reg,
                             l1_collection=self.l1_collection,
                             dtype=self.dtype())
                for ii, aa in enumerate(config.conv_act_fn)
            ]
        elif config.norm_field == "div":
            log.info("Using divisive normalization")
            log.info("Setting sigma={:.3e}".format(config.sigma_init))
            log.info("Setting sigma learnable={}".format(config.learn_sigma))
            log.info("Setting L1={:.3e}".format(config.l1_reg))
            conv_act_fn = [
                get_dn_act(sum_window=config.sum_window[ii],
                           sup_window=config.sup_window[ii],
                           act=get_tf_fn(aa),
                           sigma_init=config.sigma_init,
                           affine=config.norm_affine,
                           l1_reg=config.l1_reg,
                           l1_collection=self.l1_collection,
                           learn_sigma=config.learn_sigma,
                           dtype=self.dtype())
                for ii, aa in enumerate(config.conv_act_fn)
            ]
        elif config.norm_field == "div_ms":
            log.info("Using mean subtracted divisive normalization")
            log.info("Setting L1={:.3e}".format(config.l1_reg))
            conv_act_fn = [
                get_dnms_act(sum_window=config.sum_window[ii],
                             act=get_tf_fn(aa),
                             affine=config.norm_affine,
                             l1_reg=config.l1_reg,
                             l1_collection=self.l1_collection,
                             dtype=self.dtype())
                for ii, aa in enumerate(config.conv_act_fn)
            ]
        elif config.norm_field == "no" or config.norm_field is None:
            log.info("Not using normalization")
            log.info("Setting L1={:.3e}".format(config.l1_reg))
            conv_act_fn = [
                get_reg_act(get_tf_fn(aa),
                            l1_reg=config.l1_reg,
                            l1_collection=self.l1_collection)
                for aa in config.conv_act_fn
            ]
        else:
            raise Exception("Unknown normalization \"{}\"".format(
                config.norm_field))

        # Pooling functions.
        pool_fn = [get_tf_fn(pp) for pp in config.pool_fn]

        # CNN function.
        cnn_fn = lambda x: nn.cnn(x,
                                  config.filter_size,
                                  strides=config.strides,
                                  pool_fn=pool_fn,
                                  pool_size=config.pool_size,
                                  pool_strides=config.pool_strides,
                                  act_fn=conv_act_fn,
                                  dtype=self.dtype(),
                                  add_bias=True,
                                  init_std=config.conv_init_std,
                                  init_method=config.conv_init_method,
                                  wd=config.wd)

        # MLP function.
        mlp_act_fn = [get_tf_fn(aa) for aa in config.mlp_act_fn]
        mlp_fn = lambda x: nn.mlp(x,
                                  config.mlp_dims,
                                  is_training=is_training,
                                  act_fn=mlp_act_fn,
                                  dtype=self.dtype(),
                                  init_std=config.mlp_init_std,
                                  init_method=config.mlp_init_method,
                                  dropout=config.mlp_dropout,
                                  wd=config.wd)

        # Prediction model.
        h = cnn_fn(x)
        # [print(n.name) for n in tf.get_default_graph().as_graph_def().node]
        h = tf.reshape(h, [-1, config.mlp_dims[0]])
        logits = mlp_fn(h)
        return logits
def get_model(opt, device='/cpu:0'):
    """The attention model"""
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    # New parameters for double attention.
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']
    attn_rnn_hid_dim = opt['attn_rnn_hid_dim']

    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    use_gt_attn = opt['use_gt_attn']
    segm_loss_fn = opt['segm_loss_fn']
    box_loss_fn = opt['box_loss_fn']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    use_attn_rnn = opt['use_attn_rnn']
    use_knob = opt['use_knob']
    knob_base = opt['knob_base']
    knob_decay = opt['knob_decay']
    steps_per_knob_decay = opt['steps_per_knob_decay']
    use_canvas = opt['use_canvas']
    knob_box_offset = opt['knob_box_offset']
    knob_segm_offset = opt['knob_segm_offset']
    knob_use_timescale = opt['knob_use_timescale']
    gt_selector = opt['gt_selector']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']
    downsample_canvas = opt['downsample_canvas']
    pretrain_cnn = opt['pretrain_cnn']
    cnn_share_weights = opt['cnn_share_weights']
    squash_ctrl_params = opt['squash_ctrl_params']
    use_iou_box = opt['use_iou_box']
    clip_gradient = opt['clip_gradient']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

############################
# Input definition
############################
    with tf.device(get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth])
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width])

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan])

        # Whether in training stage.
        phase_train = tf.placeholder('bool')
        phase_train_f = tf.to_float(phase_train)
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0)

###############################
# Random input transformation
###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

############################
# Canvas: external memory
############################
        if use_canvas:
            canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
            ccnn_inp_depth = inp_depth + 1
            acnn_inp_depth = inp_depth + 1
        else:
            ccnn_inp_depth = inp_depth
            acnn_inp_depth = inp_depth

############################
# Controller CNN definition
############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers
        if pretrain_cnn:
            h5f = h5py.File(pretrain_cnn, 'r')
            ccnn_init_w = [{'w': h5f['cnn_w_{}'.format(ii)][:],
                            'b': h5f['cnn_b_{}'.format(ii)][:]}
                           for ii in xrange(ccnn_nlayers)]
            ccnn_frozen = True
        else:
            ccnn_init_w = None
            ccnn_frozen = None

        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model, init_weights=ccnn_init_w,
                      frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

############################
# Controller RNN definition
############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        canvas_dim = inp_height * inp_width / (ccnn_subsample ** 2)
        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        # crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_cell = nn.lstm(glimpse_feat_dim, crnn_dim,
                            wd=wd, scope='ctrl_lstm', model=model)

############################
# Glimpse MLP definition
############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None
        gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True,
                      dropout_keep=gmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='glimpse_mlp',
                      model=model)

############################
# Controller MLP definition
############################
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        # cmlp_dropout = [1.0 - mlp_dropout_ratio] * num_ctrl_mlp_layers
        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='ctrl_mlp',
                      model=model)
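        # The 9 controller outputs are sliced below as: dims 0-1 box center,
        # dims 2-3 log size, dim 6 attention log gamma, dim 7 box log gamma,
        # dim 8 output log gamma. Dims 4-5 are unused here; the log variance
        # is fixed to zero.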

############################
# Attention CNN definition
############################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = attn_cnn_pool
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers
        if cnn_share_weights:
            ccnn_shared_weights = []
            for ii in xrange(ccnn_nlayers):
                ccnn_shared_weights.append(
                    {'w': model['ctrl_cnn_w_{}'.format(ii)],
                     'b': model['ctrl_cnn_b_{}'.format(ii)]})
        else:
            ccnn_shared_weights = None
        acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act,
                      acnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='attn_cnn', model=model,
                      shared_weights=ccnn_shared_weights)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

############################
# Attention RNN definition
############################
        acnn_subsample = np.array(acnn_pool).prod()
        arnn_h = filter_height / acnn_subsample
        arnn_w = filter_width / acnn_subsample

        if use_attn_rnn:
            arnn_dim = attn_rnn_hid_dim
            arnn_inp_dim = arnn_h * arnn_w * acnn_channels[-1]
            arnn_state = [None] * (timespan + 1)
            arnn_g_i = [None] * timespan
            arnn_g_f = [None] * timespan
            arnn_g_o = [None] * timespan
            arnn_state[-1] = tf.zeros(tf.pack([num_ex, arnn_dim * 2]))
            arnn_cell = nn.lstm(arnn_inp_dim, arnn_dim,
                                wd=wd, scope='attn_lstm')
            amlp_inp_dim = arnn_dim
        else:
            amlp_inp_dim = arnn_h * arnn_w * acnn_channels[-1]

############################
# Attention MLP definition
############################
        core_depth = attn_mlp_depth
        core_dim = arnn_h * arnn_w * core_depth
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None
        # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers
        amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout,
                      phase_train=phase_train, wd=wd, scope='attn_mlp',
                      model=model)

        # DCNN [B, RH, RW, MD] => [B, A, A, 1]
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = attn_dcnn_pool
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth
        adcnn_use_bn = [use_bn] * adcnn_nlayers
        adcnn_skip_ch = [0] + acnn_channels[::-1][1:]
        adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool,
                        adcnn_act, use_bn=adcnn_use_bn, skip_ch=adcnn_skip_ch,
                        phase_train=phase_train, wd=wd, model=model,
                        scope='attn_dcnn')
        h_adcnn = [None] * timespan

##########################
# Score MLP definition
##########################
        smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid], wd=wd,
                      scope='score_mlp', model=model)
        s_out = [None] * timespan

##########################
# Attention box
##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_lg_gamma = [None] * timespan
        attn_gamma = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        iou_soft_box = [None] * timespan
        const_ones = tf.ones(
            tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        attn_box_gamma = [None] * timespan

#############################
# Groundtruth attention box
#############################
        # [B, T, 2]
        attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \
            attn_top_left_gt, attn_bot_right_gt = \
            base.get_gt_attn(y_gt,
                             padding_ratio=attn_box_padding_ratio,
                             center_shift_ratio=0.0)
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise))

##########################
# Groundtruth mix
##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))
        # Add a small bias to every entry so there are no duplicate matches.
        # [1, N]
        iou_bias_eps = 1e-7
        iou_bias = tf.expand_dims(tf.to_float(
            tf.reverse(tf.range(timespan), [True])) * iou_bias_eps, 0)
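        # For example, with timespan = 3 the bias is [[2e-7, 1e-7, 0.0]], so
        # ties in the greedy matching below are presumably broken toward
        # earlier groundtruth entries.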

        # Scale mix ratio on different timesteps.
        gt_knob_time_scale = tf.reshape(
            1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0 *
                         float(knob_use_timescale)), [1, timespan, 1])

        # Mix in groundtruth box.
        global_step_box = tf.maximum(0.0, global_step - knob_box_offset)
        gt_knob_prob_box = tf.train.exponential_decay(
            knob_base, global_step_box, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_box = tf.minimum(
            1.0, gt_knob_prob_box * gt_knob_time_scale)
        gt_knob_box = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box)
        model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0]

        # Mix in groundtruth segmentation.
        global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset)
        gt_knob_prob_segm = tf.train.exponential_decay(
            knob_base, global_step_segm, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_segm = tf.minimum(
            1.0, gt_knob_prob_segm * gt_knob_time_scale)
        gt_knob_segm = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm)
        model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0]
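        # Both knobs implement scheduled sampling: with probability
        # gt_knob_prob_* (exponentially decayed from knob_base, scaled up for
        # later timesteps when knob_use_timescale is set, and active only
        # during training), the noisy groundtruth box / segmentation is fed
        # in place of the model's own prediction.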

##########################
# Segmentation output
##########################
        y_out = [None] * timespan
        y_out_lg_gamma = [None] * timespan
        y_out_beta = tf.constant([-5.0])

##########################
# Computation graph
##########################
        if not use_canvas:
            h_ccnn = ccnn(x)

        for tt in xrange(timespan):
            # Controller CNN [B, H, W, D] => [B, RH1, RW1, RD1]
            if use_canvas:
                ccnn_inp = tf.concat(3, [x, canvas])
                acnn_inp = ccnn_inp
                h_ccnn[tt] = ccnn(ccnn_inp)
                _h_ccnn = h_ccnn[tt]
            else:
                ccnn_inp = x
                acnn_inp = x
                _h_ccnn = h_ccnn

            h_ccnn_last = _h_ccnn[-1]
            # crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
            crnn_inp = tf.reshape(
                h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])

            crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
            crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
            crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
            crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
            h_crnn[tt] = [None] * num_ctrl_rnn_iter

            crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
            # if tt == 0:
            #     crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
            # else:
            #     crnn_state[tt][-1] = crnn_state[tt - 1][num_ctrl_rnn_iter - 1]

            crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
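            # The first glimpse attends uniformly (1 / glimpse_map_dim) over
            # all feature-map locations; subsequent glimpses come from the
            # glimpse MLP's softmax output below.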
            crnn_glimpse_map[tt][0] = tf.ones(
                tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim
            for tt2 in xrange(num_ctrl_rnn_iter):
                crnn_glimpse = tf.reduce_sum(
                    crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
                    crnn_g_o[tt][tt2] = \
                    crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
                h_crnn[tt][tt2] = tf.slice(
                    crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim])
                h_gmlp = gmlp(h_crnn[tt][tt2])
                if tt2 < num_ctrl_rnn_iter - 1:
                    crnn_glimpse_map[tt][
                        tt2 + 1] = tf.expand_dims(h_gmlp[-1], 2)

            ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
            attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1])
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1])

            attn_gamma[tt] = tf.reshape(
                tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1])
            attn_box_gamma[tt] = tf.reshape(tf.exp(
                attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
            y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1])

            # Initial filters (predicted)
            filter_y = get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            if use_iou_box:
                _idx_map = get_idx_map(
                    tf.pack([num_ex, inp_height, inp_width]))
                attn_top_left[tt], attn_bot_right[tt] = get_box_coord(
                    attn_ctr[tt], attn_size[tt])
                attn_box[tt] = get_filled_box_idx(
                    _idx_map, attn_top_left[tt], attn_bot_right[tt])
                attn_box[tt] = tf.reshape(attn_box[tt],
                                          [-1, 1, inp_height, inp_width])
            else:
                attn_box[tt] = extract_patch(const_ones * attn_box_gamma[tt],
                                             filter_y_inv, filter_x_inv, 1)
                attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
                attn_box[tt] = tf.reshape(attn_box[tt],
                                          [-1, 1, inp_height, inp_width])

            # Here the knob mixes the groundtruth bbox into the prediction.
            if use_knob:
                # IOU [B, 1, T]
                # [B, 1, H, W] * [B, T, H, W] = [B, T]
                if use_iou_box:
                    _top_left = tf.expand_dims(attn_top_left[tt], 1)
                    _bot_right = tf.expand_dims(attn_bot_right[tt], 1)
                    iou_soft_box[tt] = f_iou_box(
                        _top_left, _bot_right, attn_top_left_gt,
                        attn_bot_right_gt)
                    iou_soft_box[tt] += iou_bias
                else:
                    iou_soft_box[tt] = f_inter(attn_box[tt], attn_box_gt) / \
                        f_union(attn_box[tt], attn_box_gt, eps=1e-5)

                grd_match = f_greedy_match(iou_soft_box[tt], grd_match_cum)

                if gt_selector == 'greedy_match':
                    # Accumulate the matches so entries are not double counted.
                    grd_match_cum += grd_match

                # [B, T, 1]
                grd_match = tf.expand_dims(grd_match, 2)
                attn_ctr_gt_match = tf.reduce_sum(
                    grd_match * attn_ctr_gt_noise, 1)
                attn_size_gt_match = tf.reduce_sum(
                    grd_match * attn_size_gt_noise, 1)

                _gt_knob_box = gt_knob_box
                attn_ctr[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_ctr_gt_match + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_ctr[tt]
                attn_size[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_size_gt_match + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_size[tt]

            attn_top_left[tt], attn_bot_right[tt] = get_box_coord(
                attn_ctr[tt], attn_size[tt])

            filter_y = get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attended patch [B, A, A, D]
            x_patch[tt] = attn_gamma[tt] * extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]

            if use_attn_rnn:
                # RNN [B, T, R2]
                arnn_inp = tf.reshape(h_acnn_last[tt], [-1, arnn_inp_dim])
                arnn_state[tt], arnn_g_i[tt], arnn_g_f[tt], arnn_g_o[tt] = \
                    arnn_cell(arnn_inp, arnn_state[tt - 1])

            # Scoring network
            s_out[tt] = smlp(h_crnn[tt][-1])[-1]

            # Dense segmentation network [B, R] => [B, M]
            if use_attn_rnn:
                h_arnn = tf.slice(
                    arnn_state[tt], [0, arnn_dim], [-1, arnn_dim])
                amlp_inp = h_arnn
            else:
                amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core = tf.reshape(h_core, [-1, arnn_h, arnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core, skip=skip)

            # Output
            y_out[tt] = extract_patch(
                h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width])

            # Here the knob mixes in the groundtruth segmentation at this
            # timestep.
            # [B, N, 1, 1]
            if use_canvas:
                if use_knob:
                    _gt_knob_segm = tf.expand_dims(
                        tf.expand_dims(gt_knob_segm[:, tt, 0: 1], 2), 3)
                    # [B, N, 1, 1]
                    grd_match = tf.expand_dims(grd_match, 3)
                    _y_out = tf.expand_dims(tf.reduce_sum(
                        grd_match * y_gt, 1), 3)
                    # Add independent uniform noise to groundtruth.
                    _noise = tf.random_uniform(
                        tf.pack([num_ex, inp_height, inp_width, 1]),
                        0, gt_segm_noise)
                    _y_out = _y_out - _y_out * _noise
                    _y_out = phase_train_f * _gt_knob_segm * _y_out + \
                        (1 - phase_train_f * _gt_knob_segm) * \
                        tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1])
                else:
                    _y_out = tf.reshape(y_out[tt],
                                        [-1, inp_height, inp_width, 1])
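                # Write the mask into the canvas through stop_gradient, so
                # the canvas acts as external memory without backpropagating
                # through previous timesteps.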
                canvas += tf.stop_gradient(_y_out)

#########################
# Model outputs
#########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        x_patch = tf.concat(1, [tf.expand_dims(x_patch[tt], 1)
                                for tt in xrange(timespan)])
        model['x_patch'] = x_patch

#########################
# Loss function
#########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

############################
# Box loss
############################
        if use_knob:
            iou_soft_box = tf.concat(1, [tf.expand_dims(iou_soft_box[tt], 1)
                                         for tt in xrange(timespan)])
        else:
            iou_soft_box = f_iou(attn_box, attn_box_gt,
                                 timespan, pairwise=True)

        model['iou_soft_box'] = iou_soft_box
        model['attn_box_gt'] = attn_box_gt
        match_box = f_segm_match(iou_soft_box, s_gt)
        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(
            match_sum_box, reduction_indices=[1])
        match_count_box = tf.maximum(1.0, match_count_box)
        iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(tf.reduce_sum(iou_soft_box_mask, [1])
                                     / match_count_box) / num_ex_f
        gt_wt_box = f_coverage_weight(attn_box_gt)
        wt_iou_soft_box = tf.reduce_sum(tf.reduce_sum(
            iou_soft_box_mask * gt_wt_box, [1])
            / match_count_box) / num_ex_f
        if box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_iou':
            box_loss = -wt_iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -f_weighted_coverage(iou_soft_box, attn_box_gt)
        elif box_loss_fn == 'bce':
            box_loss = f_match_bce(attn_box, attn_box_gt, match_box, timespan)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

##############################
# Segmentation loss
##############################
        # IoU (soft)
        iou_soft = f_iou(y_out, y_gt, timespan, pairwise=True)
        match = f_segm_match(iou_soft, s_gt)
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = f_weighted_coverage(iou_soft, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = f_unweighted_coverage(iou_soft, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # IOU (soft)
        iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(tf.reduce_sum(iou_soft_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_soft'] = iou_soft
        gt_wt = f_coverage_weight(y_gt)
        wt_iou_soft = tf.reduce_sum(tf.reduce_sum(iou_soft_mask * gt_wt, [1]) /
                                    match_count) / num_ex_f
        model['wt_iou_soft'] = wt_iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_iou':
            segm_loss = -wt_iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

####################
# Score loss
####################
        conf_loss = f_conf_loss(s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        tf.add_to_collection('losses', loss_mix_ratio * conf_loss)

####################
# Total loss
####################
        total_loss = tf.add_n(tf.get_collection(
            'losses'), name='total_loss')
        model['loss'] = total_loss

####################
# Optimizer
####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

####################
# Statistics
####################
        # [B, M, N] * [B, M, N] => [B] * [B] => [1]
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        # [B, T]
        iou_hard_mask = tf.reduce_sum(iou_hard * match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_hard'] = iou_hard
        wt_iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask * gt_wt, [1]) /
                                    match_count) / num_ex_f
        model['wt_iou_hard'] = wt_iou_hard

        dice = f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(dice * match, [1, 2]) /
                             match_count) / num_ex_f
        model['dice'] = dice

        model['count_acc'] = f_count_acc(s_out, s_gt)
        model['dic'] = f_dic(s_out, s_gt, abs=False)
        model['dic_abs'] = f_dic(s_out, s_gt, abs=True)

################################
# Controller output statistics
################################
        attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_top_left])
        attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1)
                                 for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                  for tmp in attn_size])
        attn_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_lg_gamma])
        attn_box_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                          for tmp in attn_box_lg_gamma])
        y_out_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in y_out_lg_gamma])
        attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / num_ex_f / timespan
        attn_box_lg_gamma_mean = tf.reduce_sum(
            attn_box_lg_gamma) / num_ex_f / timespan
        y_out_lg_gamma_mean = tf.reduce_sum(
            y_out_lg_gamma) / num_ex_f / timespan
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_lg_gamma_mean'] = attn_lg_gamma_mean
        model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean
        model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean

    return model
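# A minimal usage sketch for the model dict returned by get_model above. This
# is illustrative only, not part of the original example: it assumes numpy
# arrays x_np [B, H, W, D], y_np [B, T, H, W], s_np [B, T], and the TF 0.x-era
# API that this code targets.
def train_sketch(opt, x_np, y_np, s_np, num_steps=100):
    model = get_model(opt)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for step in xrange(num_steps):
            # Feed groundtruth and run one optimizer step.
            _, loss = sess.run(
                [model['train_step'], model['loss']],
                feed_dict={model['x']: x_np,
                           model['y_gt']: y_np,
                           model['s_gt']: s_np,
                           model['phase_train']: True})
            if step % 10 == 0:
                print('step {}: loss {:.4f}'.format(step, loss))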
Example #8
def get_model(opt, device='/cpu:0'):
    """The attention model"""
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']

    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    segm_loss_fn = opt['segm_loss_fn']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']
    clip_gradient = opt['clip_gradient']
    fixed_order = opt['fixed_order']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

############################
# Input definition
############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth],
                           name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width],
                              name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Order in which we feed in the samples
        order = tf.placeholder('int32', [None, timespan], name='order')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        phase_train_f = tf.to_float(phase_train)

        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train
        model['order'] = order

        # Global step
        global_step = tf.Variable(0.0, name='global_step')

###############################
# Random input transformation
###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

############################
# Canvas: external memory
############################
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        acnn_inp_depth = inp_depth + 1

###########################
# Attention CNN definition
###########################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = attn_cnn_pool
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers
        acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act,
                      acnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='attn_cnn', model=model)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

############################
# Attention MLP definition
############################
        acnn_subsample = np.array(acnn_pool).prod()
        acnn_h = filter_height / acnn_subsample
        acnn_w = filter_width / acnn_subsample
        core_depth = attn_mlp_depth
        core_dim = acnn_h * acnn_w * core_depth
        amlp_inp_dim = acnn_h * acnn_w * acnn_channels[-1]
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None
        # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers
        amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout,
                      phase_train=phase_train, wd=wd, scope='attn_mlp',
                      model=model)

#############################
# Attention DCNN definition
#############################
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = attn_dcnn_pool
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth
        adcnn_bn_nlayers = adcnn_nlayers
        # adcnn_bn_nlayers = adcnn_nlayers - 1
        adcnn_use_bn = [use_bn] * adcnn_bn_nlayers + \
            [False] * (adcnn_nlayers - adcnn_bn_nlayers)
        adcnn_skip_ch = [0] + acnn_channels[::-1][1:]
        adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool,
                        adcnn_act, use_bn=adcnn_use_bn, skip_ch=adcnn_skip_ch,
                        phase_train=phase_train, wd=wd, model=model,
                        scope='attn_dcnn')
        h_adcnn = [None] * timespan

##########################
# Attention box
##########################
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan

#############################
# Groundtruth attention box
#############################
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise),
                             min_padding=25.0)
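        # The padding ratio and box center are jittered with uniform noise
        # (gt_box_pad_noise, gt_box_ctr_noise), a data augmentation on the
        # groundtruth attention boxes.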

##########################
# Segmentation output
##########################
        y_out = [None] * timespan
        y_out_lg_gamma = tf.constant([2.0])
        y_out_beta = tf.constant([-5.0])

##########################
# Computation graph
##########################
        for tt in xrange(timespan):
            # Get a new greedy match based on order.
            if fixed_order:
                attn_ctr[tt] = attn_ctr_gt_noise[:, tt, :]
                attn_size[tt] = attn_size_gt_noise[:, tt, :]
            else:
                mask = _get_idx_mask(order[:, tt], timespan)
                # [B, T, 1]
                mask = tf.expand_dims(mask, 2)
                attn_ctr[tt] = tf.reduce_sum(mask * attn_ctr_gt_noise, 1)
                attn_size[tt] = tf.reduce_sum(mask * attn_size_gt_noise, 1)

            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            # [B, H, H']
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0], 0.0,
                inp_height, filter_height)

            # [B, W, W']
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1], 0.0,
                inp_width, filter_width)

            # [B, H', H]
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])

            # [B, W', W]
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attended patch [B, A, A, D]
            acnn_inp = tf.concat(3, [x, canvas])
            x_patch[tt] = base.extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]
            amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core = tf.reshape(h_core, [-1, acnn_h, acnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core, skip=skip)

            # Output
            y_out[tt] = base.extract_patch(
                h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width])

            # Canvas
            if fixed_order:
                _y_out = y_gt[:, tt, :, :]
            else:
                mask = tf.expand_dims(mask, 3)
                _y_out = tf.reduce_sum(mask * y_gt, 1)

            _y_out = tf.expand_dims(_y_out, 3)
            # Add independent uniform noise to groundtruth.
            _noise = tf.random_uniform(
                tf.pack([num_ex, inp_height, inp_width, 1]),
                0, gt_segm_noise)
            _y_out = _y_out - _y_out * _noise
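            # Unlike the knob-based example above, the canvas here always
            # accumulates the (noisy) groundtruth masks, i.e. pure teacher
            # forcing.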
            canvas += _y_out

#########################
# Model outputs
#########################
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        x_patch = tf.concat(1, [tf.expand_dims(x_patch[tt], 1)
                                for tt in xrange(timespan)])
        model['x_patch'] = x_patch

        attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_top_left])
        attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1)
                                 for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                  for tmp in attn_size])
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right

#########################
# Loss function
#########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

##############################
# Segmentation loss
##############################
        # Matching
        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        iou_soft_pairwise = base.f_iou(y_out, y_gt, timespan, pairwise=True)
        real_match = base.f_segm_match(iou_soft_pairwise, s_gt)
        if fixed_order:
            iou_soft = base.f_iou(y_out, y_gt, pairwise=False)
            match = identity_match
        else:
            iou_soft = iou_soft_pairwise
            match = real_match
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = base.f_weighted_coverage(iou_soft_pairwise, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = base.f_unweighted_coverage(
            iou_soft_pairwise, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # IOU (soft)
        if fixed_order:
            iou_soft_mask = iou_soft
        else:
            iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(iou_soft_mask, [1])
        iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f
        model['iou_soft'] = iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_iou':
            # Weighted IoU (soft): weight each matched IoU by groundtruth
            # coverage, as in the previous example (wt_iou_soft is not
            # precomputed in this one).
            gt_wt = base.f_coverage_weight(y_gt)
            wt_iou_soft = tf.reduce_sum(
                tf.reduce_sum(iou_soft_mask * gt_wt, [1]) /
                match_count) / num_ex_f
            segm_loss = -wt_iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = base.f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

####################
# Total loss
####################
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

####################
# Optimizer
####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

####################
# Statistics
####################
        # Hard statistics here are always computed with matching.
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = base.f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = base.f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = base.f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_hard'] = iou_hard

        dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(
            dice * real_match, reduction_indices=[1, 2]) / match_count) / \
            num_ex_f
        model['dice'] = dice

    return model
Example #9
def get_model(opt, device='/cpu:0'):
    """The box model"""
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    box_loss_fn = opt['box_loss_fn']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    pretrain_cnn = opt['pretrain_cnn']
    squash_ctrl_params = opt['squash_ctrl_params']
    clip_gradient = opt['clip_gradient']
    fixed_order = opt['fixed_order']
    ctrl_rnn_inp_struct = opt['ctrl_rnn_inp_struct']  # dense or attn
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']
    
    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

############################
# Input definition
############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth],
                           name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width],
                              name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        phase_train_f = tf.to_float(phase_train)
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0, name='global_step')

###############################
# Random input transformation
###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

############################
# Canvas: external memory
############################
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        ccnn_inp_depth = inp_depth + 1
        acnn_inp_depth = inp_depth + 1

############################
# Controller CNN definition
############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers

        if pretrain_cnn:
            log.info('Loading pretrained weights from {}'.format(pretrain_cnn))
            h5f = h5py.File(pretrain_cnn, 'r')
            acnn_nlayers = 0
            # Assuming acnn_nlayers is smaller than ccnn_nlayers.
            for ii in xrange(ccnn_nlayers):
                if 'attn_cnn_w_{}'.format(ii) in h5f:
                    log.info('Loading attn_cnn_w_{}'.format(ii))
                    log.info('Loading attn_cnn_b_{}'.format(ii))
                    acnn_nlayers += 1
            ccnn_init_w = [{'w': h5f['attn_cnn_w_{}'.format(ii)][:],
                            'b': h5f['attn_cnn_b_{}'.format(ii)][:]}
                           for ii in xrange(acnn_nlayers)]
            for ii in xrange(acnn_nlayers):
                for tt in xrange(timespan):
                    for w in ['beta', 'gamma']:
                        ccnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[
                            'attn_cnn_{}_{}_{}'.format(ii, tt, w)][:]
            ccnn_frozen = [True] * acnn_nlayers
            for ii in xrange(acnn_nlayers, ccnn_nlayers):
                ccnn_init_w.append(None)
                ccnn_frozen.append(False)
        else:
            ccnn_init_w = None
            ccnn_frozen = None

        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model, init_weights=ccnn_init_w,
                      frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

############################
# Controller RNN definition
############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        canvas_dim = inp_height * inp_width / (ccnn_subsample ** 2)

        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        if ctrl_rnn_inp_struct == 'dense':
            crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        elif ctrl_rnn_inp_struct == 'attn':
            crnn_inp_dim = glimpse_feat_dim

        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm',
                            model=model)

############################
# Glimpse MLP definition
############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None
        gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True,
                      dropout_keep=gmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='glimpse_mlp',
                      model=model)

############################
# Controller MLP definition
############################
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='ctrl_mlp',
                      model=model)

##########################
# Score MLP definition
##########################
        smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid], wd=wd, scope='score_mlp')
        s_out = [None] * timespan

##########################
# Attention box
##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_box_gamma = [None] * timespan
        const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        iou_soft_box = [None] * timespan

#############################
# Groundtruth attention box
#############################
        attn_top_left_gt, attn_bot_right_gt, attn_box_gt = base.get_gt_box(
            y_gt, padding_ratio=attn_box_padding_ratio, center_shift_ratio=0.0)
        attn_ctr_gt, attn_size_gt = base.get_box_ctr_size(
            attn_top_left_gt, attn_bot_right_gt)
        attn_ctr_norm_gt = base.get_normalized_center(
            attn_ctr_gt, inp_height, inp_width)
        attn_lg_size_gt = base.get_normalized_size(
            attn_size_gt, inp_height, inp_width)

##########################
# Groundtruth mix
##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

##########################
# Computation graph
##########################
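        # Each timestep: read the image together with the canvas of already-
        # explained pixels, run the controller CNN and RNN, and predict one
        # attention box. The matched groundtruth is then painted onto the
        # canvas so that the next timestep attends to a different object.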
        for tt in xrange(timespan):
            # Controller CNN
            ccnn_inp = tf.concat(3, [x, canvas])
            acnn_inp = ccnn_inp
            h_ccnn[tt] = ccnn(ccnn_inp)
            _h_ccnn = h_ccnn[tt]
            h_ccnn_last = _h_ccnn[-1]

            # Controller RNN [B, R1]
            if ctrl_rnn_inp_struct == 'dense':
                crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
                crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[tt] = \
                    crnn_cell(crnn_inp, crnn_state[tt - 1])
                h_crnn[tt] = tf.slice(
                    crnn_state[tt], [0, crnn_dim], [-1, crnn_dim])

                ctrl_out = cmlp(h_crnn[tt])[-1]

            elif ctrl_rnn_inp_struct == 'attn':
                crnn_inp = tf.reshape(
                    h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])
                crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
                crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
                h_crnn[tt] = [None] * num_ctrl_rnn_iter

                crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))

                crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
                crnn_glimpse_map[tt][0] = tf.ones(
                    tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

                # Inner glimpse RNN: a soft attention map over the CNN
                # feature map selects a glimpse, the LSTM ingests it, and
                # the glimpse MLP proposes the next iteration's map.
                for tt2 in xrange(num_ctrl_rnn_iter):
                    crnn_glimpse = tf.reduce_sum(
                        crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                    crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
                        crnn_g_o[tt][tt2] = \
                        crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
                    h_crnn[tt][tt2] = tf.slice(
                        crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim])
                    h_gmlp = gmlp(h_crnn[tt][tt2])
                    if tt2 < num_ctrl_rnn_iter - 1:
                        crnn_glimpse_map[tt][
                            tt2 + 1] = tf.expand_dims(h_gmlp[-1], 2)

                ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            attn_box_gamma[tt] = tf.reshape(tf.exp(
                attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            # Initial filters (predicted)
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            attn_box[tt] = base.extract_patch(
                const_ones * attn_box_gamma[tt],
                filter_y_inv, filter_x_inv, 1)
            attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
            attn_box[tt] = tf.reshape(attn_box[tt],
                                      [-1, 1, inp_height, inp_width])

            if fixed_order:
                _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
            else:
                iou_soft_box[tt] = base.f_inter(
                    attn_box[tt], attn_box_gt) / \
                    base.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
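                # Greedily match this prediction to the groundtruth box with
                # the highest soft IOU; grd_match_cum carries the cumulative
                # match so that one object is not matched twice.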
                grd_match = base.f_greedy_match(
                    iou_soft_box[tt], grd_match_cum)
                grd_match = tf.expand_dims(tf.expand_dims(grd_match, 2), 3)
                _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1), 3)

            # Add independent uniform noise to groundtruth.
            _noise = tf.random_uniform(
                tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3)
            _y_out = _y_out - _y_out * _noise
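            # Composite with max() and stop the gradient: the canvas acts as
            # a read-only external memory, so no gradient flows back through
            # earlier timesteps via the canvas.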
            canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))
            # canvas += tf.stop_gradient(_y_out)

            # Scoring network
            if ctrl_rnn_inp_struct == 'dense':
                s_out[tt] = smlp(h_crnn[tt])[-1]
            elif ctrl_rnn_inp_struct == 'attn':
                s_out[tt] = smlp(h_crnn[tt][-1])[-1]

#########################
# Model outputs
#########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_top_left])
        attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_ctr_norm_gt'] = attn_ctr_norm_gt
        model['attn_lg_size_gt'] = attn_lg_size_gt
        model['attn_top_left_gt'] = attn_top_left_gt
        model['attn_bot_right_gt'] = attn_bot_right_gt
        model['attn_box_gt'] = attn_box_gt
        attn_ctr_norm = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_ctr_norm])
        attn_lg_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                     for tmp in attn_lg_size])
        model['attn_ctr_norm'] = attn_ctr_norm
        model['attn_lg_size'] = attn_lg_size

        attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
        attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

#########################
# Loss function
#########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

############################
# Box loss
############################
        if fixed_order:
            # [B, T] for fixed order.
            iou_soft_box = base.f_iou(attn_box, attn_box_gt, pairwise=False)
        else:
            # [B, T, T] for matching.
            iou_soft_box = tf.concat(1, [tf.expand_dims(iou_soft_box[tt], 1)
                                         for tt in xrange(timespan)])

        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        if fixed_order:
            match_box = identity_match
        else:
            match_box = base.f_segm_match(iou_soft_box, s_gt)

        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(
            match_sum_box, reduction_indices=[1])
        match_count_box = tf.maximum(1.0, match_count_box)

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_box_mask = iou_soft_box
        else:
            iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
        iou_soft_box = tf.reduce_sum(
            iou_soft_box / match_count_box) / num_ex_f

        if box_loss_fn == 'mse':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan,
                base.f_squared_err, model=model)
        elif box_loss_fn == 'huber':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan, base.f_huber)
        elif box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        # Note: the 'wt_iou', 'wt_cov' and 'bce' branches assume that
        # wt_iou_soft_box, box_map_gt and box_map are defined (as in the
        # attention-model variant); they are not computed in this function.
        elif box_loss_fn == 'wt_iou':
            box_loss = -wt_iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -base.f_weighted_coverage(iou_soft_box, box_map_gt)
        elif box_loss_fn == 'bce':
            box_loss = base.f_match_loss(
                box_map, box_map_gt, match_box, timespan, base.f_bce)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        model['box_loss_coeff'] = box_loss_coeff
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

####################
# Score loss
####################
        conf_loss = base.f_conf_loss(
            s_out, match_box, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        conf_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', conf_loss_coeff * conf_loss)

####################
# Total loss
####################
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

####################
# Optimizer
####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7
        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

####################
# Glimpse
####################
        # T * T2 * [B, H' * W'] => [B, T, T2, H', W']
        if ctrl_rnn_inp_struct == 'attn':
            crnn_glimpse_map = tf.concat(
                1, [tf.expand_dims(tf.concat(
                    1, [tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
                        for tt2 in xrange(num_ctrl_rnn_iter)]), 1)
                    for tt in xrange(timespan)])
            crnn_glimpse_map = tf.reshape(
                crnn_glimpse_map, [-1, timespan, num_ctrl_rnn_iter, crnn_h,
                                   crnn_w])
            model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

        return model
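A minimal training-loop sketch for the box model above, assuming opt carries
the keys read at the top of get_model and that next_batch() (hypothetical)
yields numpy arrays shaped [B, H, W, D], [B, T, H, W] and [B, T]:

model = get_model(opt, device='/cpu:0')
sess = tf.Session()
sess.run(tf.initialize_all_variables())  # TF 0.x-era initializer

for step in xrange(num_steps):
    x_batch, y_batch, s_batch = next_batch()  # hypothetical data source
    _, loss = sess.run(
        [model['train_step'], model['loss']],
        feed_dict={model['x']: x_batch,
                   model['y_gt']: y_batch,
                   model['s_gt']: s_batch,
                   model['phase_train']: True})
    if step % 100 == 0:
        print 'step {} loss {:.4f}'.format(step, loss)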
Example #10
def get_model(opt, device='/cpu:0'):
    """The attention model"""
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    # New parameters for double attention.
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']
    attn_rnn_hid_dim = opt['attn_rnn_hid_dim']

    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    use_gt_attn = opt['use_gt_attn']
    segm_loss_fn = opt['segm_loss_fn']
    box_loss_fn = opt['box_loss_fn']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    use_attn_rnn = opt['use_attn_rnn']
    use_knob = opt['use_knob']
    knob_base = opt['knob_base']
    knob_decay = opt['knob_decay']
    steps_per_knob_decay = opt['steps_per_knob_decay']
    use_canvas = opt['use_canvas']
    knob_box_offset = opt['knob_box_offset']
    knob_segm_offset = opt['knob_segm_offset']
    knob_use_timescale = opt['knob_use_timescale']
    gt_selector = opt['gt_selector']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']
    downsample_canvas = opt['downsample_canvas']
    pretrain_cnn = opt['pretrain_cnn']
    cnn_share_weights = opt['cnn_share_weights']
    squash_ctrl_params = opt['squash_ctrl_params']
    use_iou_box = opt['use_iou_box']
    clip_gradient = opt['clip_gradient']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    with tf.device(get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth])
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width])

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan])

        # Whether in training stage.
        phase_train = tf.placeholder('bool')
        phase_train_f = tf.to_float(phase_train)
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0)

        ###############################
        # Random input transformation
        ###############################
        x, y_gt = img.random_transformation(x,
                                            y_gt,
                                            padding,
                                            phase_train,
                                            rnd_hflip=rnd_hflip,
                                            rnd_vflip=rnd_vflip,
                                            rnd_transpose=rnd_transpose,
                                            rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        ############################
        # Canvas: external memory
        ############################
        if use_canvas:
            canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
            ccnn_inp_depth = inp_depth + 1
            acnn_inp_depth = inp_depth + 1
        else:
            ccnn_inp_depth = inp_depth
            acnn_inp_depth = inp_depth

        ############################
        # Controller CNN definition
        ############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers
        if pretrain_cnn:
            h5f = h5py.File(pretrain_cnn, 'r')
            ccnn_init_w = [{
                'w': h5f['cnn_w_{}'.format(ii)][:],
                'b': h5f['cnn_b_{}'.format(ii)][:]
            } for ii in xrange(ccnn_nlayers)]
            ccnn_frozen = True
        else:
            ccnn_init_w = None
            ccnn_frozen = None

        ccnn = nn.cnn(ccnn_filters,
                      ccnn_channels,
                      ccnn_pool,
                      ccnn_act,
                      ccnn_use_bn,
                      phase_train=phase_train,
                      wd=wd,
                      scope='ctrl_cnn',
                      model=model,
                      init_weights=ccnn_init_w,
                      frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

        ############################
        # Controller RNN definition
        ############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        canvas_dim = inp_height * inp_width / (ccnn_subsample**2)
        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        # crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_cell = nn.lstm(glimpse_feat_dim,
                            crnn_dim,
                            wd=wd,
                            scope='ctrl_lstm',
                            model=model)

        ############################
        # Glimpse MLP definition
        ############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None
        gmlp = nn.mlp(gmlp_dims,
                      gmlp_act,
                      add_bias=True,
                      dropout_keep=gmlp_dropout,
                      phase_train=phase_train,
                      wd=wd,
                      scope='glimpse_mlp',
                      model=model)

        ############################
        # Controller MLP definition
        ############################
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        # cmlp_dropout = [1.0 - mlp_dropout_ratio] * num_ctrl_mlp_layers
        cmlp = nn.mlp(cmlp_dims,
                      cmlp_act,
                      add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train,
                      wd=wd,
                      scope='ctrl_mlp',
                      model=model)

        ############################
        # Attention CNN definition
        ############################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = attn_cnn_pool
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers
        if cnn_share_weights:
            ccnn_shared_weights = []
            for ii in xrange(ccnn_nlayers):
                ccnn_shared_weights.append({
                    'w':
                    model['ctrl_cnn_w_{}'.format(ii)],
                    'b':
                    model['ctrl_cnn_b_{}'.format(ii)]
                })
        else:
            ccnn_shared_weights = None
        acnn = nn.cnn(acnn_filters,
                      acnn_channels,
                      acnn_pool,
                      acnn_act,
                      acnn_use_bn,
                      phase_train=phase_train,
                      wd=wd,
                      scope='attn_cnn',
                      model=model,
                      shared_weights=ccnn_shared_weights)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

        ############################
        # Attention RNN definition
        ############################
        acnn_subsample = np.array(acnn_pool).prod()
        arnn_h = filter_height / acnn_subsample
        arnn_w = filter_width / acnn_subsample

        if use_attn_rnn:
            arnn_dim = attn_rnn_hid_dim
            arnn_inp_dim = arnn_h * arnn_w * acnn_channels[-1]
            arnn_state = [None] * (timespan + 1)
            arnn_g_i = [None] * timespan
            arnn_g_f = [None] * timespan
            arnn_g_o = [None] * timespan
            arnn_state[-1] = tf.zeros(tf.pack([num_ex, arnn_dim * 2]))
            arnn_cell = nn.lstm(arnn_inp_dim,
                                arnn_dim,
                                wd=wd,
                                scope='attn_lstm')
            amlp_inp_dim = arnn_dim
        else:
            amlp_inp_dim = arnn_h * arnn_w * acnn_channels[-1]

        ############################
        # Attention MLP definition
        ############################
        core_depth = attn_mlp_depth
        core_dim = arnn_h * arnn_w * core_depth
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None
        # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers
        amlp = nn.mlp(amlp_dims,
                      amlp_act,
                      dropout_keep=amlp_dropout,
                      phase_train=phase_train,
                      wd=wd,
                      scope='attn_mlp',
                      model=model)

        # DCNN [B, RH, RW, MD] => [B, A, A, 1]
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = attn_dcnn_pool
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth
        adcnn_use_bn = [use_bn] * adcnn_nlayers
        adcnn_skip_ch = [0] + acnn_channels[::-1][1:]
        adcnn = nn.dcnn(adcnn_filters,
                        adcnn_channels,
                        adcnn_unpool,
                        adcnn_act,
                        use_bn=adcnn_use_bn,
                        skip_ch=adcnn_skip_ch,
                        phase_train=phase_train,
                        wd=wd,
                        model=model,
                        scope='attn_dcnn')
        h_adcnn = [None] * timespan

        ##########################
        # Score MLP definition
        ##########################
        smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid],
                      wd=wd,
                      scope='score_mlp',
                      model=model)
        s_out = [None] * timespan

        ##########################
        # Attention box
        ##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_lg_gamma = [None] * timespan
        attn_gamma = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        iou_soft_box = [None] * timespan
        const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        attn_box_gamma = [None] * timespan

        #############################
        # Groundtruth attention box
        #############################
        # [B, T, 2]
        attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \
            attn_top_left_gt, attn_bot_right_gt = \
            base.get_gt_attn(y_gt,
                             padding_ratio=attn_box_padding_ratio,
                             center_shift_ratio=0.0)
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise))

        ##########################
        # Groundtruth mix
        ##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))
        # Add a small bias on every entry so that there are no duplicate matches.
        # [1, N]
        iou_bias_eps = 1e-7
        iou_bias = tf.expand_dims(
            tf.to_float(tf.reverse(tf.range(timespan), [True])) * iou_bias_eps,
            0)

        # "Knob": scheduled sampling. With probability gt_knob_prob_* the
        # groundtruth box/segmentation is fed back in place of the model's
        # own output, and the probability decays exponentially with the
        # global step. Scale the mix ratio on different timesteps so that
        # later timesteps see groundtruth more often.
        gt_knob_time_scale = tf.reshape(
            1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0 *
                         float(knob_use_timescale)), [1, timespan, 1])

        # Mix in groundtruth box.
        global_step_box = tf.maximum(0.0, global_step - knob_box_offset)
        gt_knob_prob_box = tf.train.exponential_decay(knob_base,
                                                      global_step_box,
                                                      steps_per_knob_decay,
                                                      knob_decay,
                                                      staircase=False)
        gt_knob_prob_box = tf.minimum(1.0,
                                      gt_knob_prob_box * gt_knob_time_scale)
        gt_knob_box = tf.to_float(
            tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <=
            gt_knob_prob_box)
        model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0]

        # Mix in groundtruth segmentation.
        global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset)
        gt_knob_prob_segm = tf.train.exponential_decay(knob_base,
                                                       global_step_segm,
                                                       steps_per_knob_decay,
                                                       knob_decay,
                                                       staircase=False)
        gt_knob_prob_segm = tf.minimum(1.0,
                                       gt_knob_prob_segm * gt_knob_time_scale)
        gt_knob_segm = tf.to_float(
            tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <=
            gt_knob_prob_segm)
        model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0]

        ##########################
        # Segmentation output
        ##########################
        y_out = [None] * timespan
        y_out_lg_gamma = [None] * timespan
        y_out_beta = tf.constant([-5.0])

        ##########################
        # Computation graph
        ##########################
        if not use_canvas:
            h_ccnn = ccnn(x)

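        # Per-timestep pipeline: the controller (CNN + glimpse RNN + MLP)
        # proposes an attention box; the attention CNN encodes the extracted
        # patch; the attention MLP and DCNN decode a segmentation mask, which
        # is pasted back to full resolution through the inverse filters.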
        for tt in xrange(timespan):
            # Controller CNN [B, H, W, D] => [B, RH1, RW1, RD1]
            if use_canvas:
                ccnn_inp = tf.concat(3, [x, canvas])
                acnn_inp = ccnn_inp
                h_ccnn[tt] = ccnn(ccnn_inp)
                _h_ccnn = h_ccnn[tt]
            else:
                ccnn_inp = x
                acnn_inp = x
                _h_ccnn = h_ccnn

            h_ccnn_last = _h_ccnn[-1]
            # crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
            crnn_inp = tf.reshape(h_ccnn_last,
                                  [-1, glimpse_map_dim, glimpse_feat_dim])

            crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
            crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
            crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
            crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
            h_crnn[tt] = [None] * num_ctrl_rnn_iter

            crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
            # if tt == 0:
            #     crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
            # else:
            #     crnn_state[tt][-1] = crnn_state[tt - 1][num_ctrl_rnn_iter - 1]

            crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
            crnn_glimpse_map[tt][0] = tf.ones(
                tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim
            for tt2 in xrange(num_ctrl_rnn_iter):
                crnn_glimpse = tf.reduce_sum(
                    crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
                    crnn_g_o[tt][tt2] = \
                    crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
                h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2], [0, crnn_dim],
                                           [-1, crnn_dim])
                h_gmlp = gmlp(h_crnn[tt][tt2])
                if tt2 < num_ctrl_rnn_iter - 1:
                    crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(
                        h_gmlp[-1], 2)

            ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
            attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1])
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1])

            attn_gamma[tt] = tf.reshape(tf.exp(attn_lg_gamma[tt]),
                                        [-1, 1, 1, 1])
            attn_box_gamma[tt] = tf.reshape(tf.exp(attn_box_lg_gamma[tt]),
                                            [-1, 1, 1, 1])
            y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1])

            # Initial filters (predicted)
            filter_y = get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            if use_iou_box:
                _idx_map = get_idx_map(tf.pack([num_ex, inp_height,
                                                inp_width]))
                attn_top_left[tt], attn_bot_right[tt] = get_box_coord(
                    attn_ctr[tt], attn_size[tt])
                attn_box[tt] = get_filled_box_idx(_idx_map, attn_top_left[tt],
                                                  attn_bot_right[tt])
                attn_box[tt] = tf.reshape(attn_box[tt],
                                          [-1, 1, inp_height, inp_width])
            else:
                attn_box[tt] = extract_patch(const_ones * attn_box_gamma[tt],
                                             filter_y_inv, filter_x_inv, 1)
                attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
                attn_box[tt] = tf.reshape(attn_box[tt],
                                          [-1, 1, inp_height, inp_width])

            # Here the knob mixes the groundtruth bbox into the prediction.
            if use_knob:
                # IOU [B, 1, T]
                # [B, 1, H, W] * [B, T, H, W] = [B, T]
                if use_iou_box:
                    _top_left = tf.expand_dims(attn_top_left[tt], 1)
                    _bot_right = tf.expand_dims(attn_bot_right[tt], 1)
                    iou_soft_box[tt] = f_iou_box(_top_left, _bot_right,
                                                 attn_top_left_gt,
                                                 attn_bot_right_gt)
                    iou_soft_box[tt] += iou_bias
                else:
                    iou_soft_box[tt] = f_inter(attn_box[tt], attn_box_gt) / \
                        f_union(attn_box[tt], attn_box_gt, eps=1e-5)

                grd_match = f_greedy_match(iou_soft_box[tt], grd_match_cum)

                if gt_selector == 'greedy_match':
                    # Add in the cumulative matching to not double count.
                    grd_match_cum += grd_match

                # [B, T, 1]
                grd_match = tf.expand_dims(grd_match, 2)
                attn_ctr_gt_match = tf.reduce_sum(
                    grd_match * attn_ctr_gt_noise, 1)
                attn_size_gt_match = tf.reduce_sum(
                    grd_match * attn_size_gt_noise, 1)

                _gt_knob_box = gt_knob_box
                attn_ctr[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_ctr_gt_match + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_ctr[tt]
                attn_size[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_size_gt_match + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_size[tt]

            attn_top_left[tt], attn_bot_right[tt] = get_box_coord(
                attn_ctr[tt], attn_size[tt])

            filter_y = get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attended patch [B, A, A, D]
            x_patch[tt] = attn_gamma[tt] * extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]

            if use_attn_rnn:
                # RNN [B, T, R2]
                arnn_inp = tf.reshape(h_acnn_last[tt], [-1, arnn_inp_dim])
                arnn_state[tt], arnn_g_i[tt], arnn_g_f[tt], arnn_g_o[tt] = \
                    arnn_cell(arnn_inp, arnn_state[tt - 1])

            # Scoring network
            s_out[tt] = smlp(h_crnn[tt][-1])[-1]

            # Dense segmentation network [B, R] => [B, M]
            if use_attn_rnn:
                h_arnn = tf.slice(arnn_state[tt], [0, arnn_dim],
                                  [-1, arnn_dim])
                amlp_inp = h_arnn
            else:
                amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core = tf.reshape(h_core, [-1, arnn_h, arnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core, skip=skip)

            # Output
            y_out[tt] = extract_patch(h_adcnn[tt][-1], filter_y_inv,
                                      filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width])

            # Here the knob mixes groundtruth segmentations into the output
            # at this timestep.
            # [B, N, 1, 1]
            if use_canvas:
                if use_knob:
                    _gt_knob_segm = tf.expand_dims(
                        tf.expand_dims(gt_knob_segm[:, tt, 0:1], 2), 3)
                    # [B, N, 1, 1]
                    grd_match = tf.expand_dims(grd_match, 3)
                    _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1),
                                            3)
                    # Add independent uniform noise to groundtruth.
                    _noise = tf.random_uniform(
                        tf.pack([num_ex, inp_height, inp_width, 1]), 0,
                        gt_segm_noise)
                    _y_out = _y_out - _y_out * _noise
                    _y_out = phase_train_f * _gt_knob_segm * _y_out + \
                        (1 - phase_train_f * _gt_knob_segm) * \
                        tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1])
                else:
                    _y_out = tf.reshape(y_out[tt],
                                        [-1, inp_height, inp_width, 1])
                canvas += tf.stop_gradient(_y_out)

        #########################
        # Model outputs
        #########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        x_patch = tf.concat(
            1, [tf.expand_dims(x_patch[tt], 1) for tt in xrange(timespan)])
        model['x_patch'] = x_patch

        #########################
        # Loss function
        #########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

        ############################
        # Box loss
        ############################
        if use_knob:
            iou_soft_box = tf.concat(1, [
                tf.expand_dims(iou_soft_box[tt], 1) for tt in xrange(timespan)
            ])
        else:
            iou_soft_box = f_iou(attn_box,
                                 attn_box_gt,
                                 timespan,
                                 pairwise=True)

        model['iou_soft_box'] = iou_soft_box
        model['attn_box_gt'] = attn_box_gt
        match_box = f_segm_match(iou_soft_box, s_gt)
        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
        match_count_box = tf.maximum(1.0, match_count_box)
        iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(
            tf.reduce_sum(iou_soft_box_mask, [1]) / match_count_box) / num_ex_f
        gt_wt_box = f_coverage_weight(attn_box_gt)
        wt_iou_soft_box = tf.reduce_sum(
            tf.reduce_sum(iou_soft_box_mask * gt_wt_box, [1]) /
            match_count_box) / num_ex_f
        if box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_iou':
            box_loss = -wt_iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -f_weighted_coverage(iou_soft_box, attn_box_gt)
        elif box_loss_fn == 'bce':
            box_loss = f_match_bce(attn_box, attn_box_gt, match_box, timespan)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

        ##############################
        # Segmentation loss
        ##############################
        # IoU (soft)
        iou_soft = f_iou(y_out, y_gt, timespan, pairwise=True)
        match = f_segm_match(iou_soft, s_gt)
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = f_weighted_coverage(iou_soft, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = f_unweighted_coverage(iou_soft, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # IOU (soft)
        iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(
            tf.reduce_sum(iou_soft_mask, [1]) / match_count) / num_ex_f
        model['iou_soft'] = iou_soft
        gt_wt = f_coverage_weight(y_gt)
        wt_iou_soft = tf.reduce_sum(
            tf.reduce_sum(iou_soft_mask * gt_wt, [1]) / match_count) / num_ex_f
        model['wt_iou_soft'] = wt_iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_iou':
            segm_loss = -wt_iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

        ####################
        # Score loss
        ####################
        conf_loss = f_conf_loss(s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        tf.add_to_collection('losses', loss_mix_ratio * conf_loss)

        ####################
        # Total loss
        ####################
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(base_learn_rate,
                                                global_step,
                                                steps_per_learn_rate_decay,
                                                learn_rate_decay,
                                                staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        train_step = GradientClipOptimizer(tf.train.AdamOptimizer(learn_rate,
                                                                  epsilon=eps),
                                           clip=clip_gradient).minimize(
                                               total_loss,
                                               global_step=global_step)
        model['train_step'] = train_step

        ####################
        # Statistics
        ####################
        # [B, M, N] * [B, M, N] => [B] * [B] => [1]
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        # [B, T]
        iou_hard_mask = tf.reduce_sum(iou_hard * match, [1])
        iou_hard = tf.reduce_sum(
            tf.reduce_sum(iou_hard_mask, [1]) / match_count) / num_ex_f
        model['iou_hard'] = iou_hard
        wt_iou_hard = tf.reduce_sum(
            tf.reduce_sum(iou_hard_mask * gt_wt, [1]) / match_count) / num_ex_f
        model['wt_iou_hard'] = wt_iou_hard

        dice = f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(
            tf.reduce_sum(dice * match, [1, 2]) / match_count) / num_ex_f
        model['dice'] = dice

        model['count_acc'] = f_count_acc(s_out, s_gt)
        model['dic'] = f_dic(s_out, s_gt, abs=False)
        model['dic_abs'] = f_dic(s_out, s_gt, abs=True)

        ################################
        # Controller output statistics
        ################################
        attn_top_left = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left])
        attn_bot_right = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
        attn_lg_gamma = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_gamma])
        attn_box_lg_gamma = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_box_lg_gamma])
        y_out_lg_gamma = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in y_out_lg_gamma])
        attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / num_ex_f / timespan
        attn_box_lg_gamma_mean = tf.reduce_sum(
            attn_box_lg_gamma) / num_ex_f / timespan
        y_out_lg_gamma_mean = tf.reduce_sum(
            y_out_lg_gamma) / num_ex_f / timespan
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_lg_gamma_mean'] = attn_lg_gamma_mean
        model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean
        model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean

    return model
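The knob probability schedule used above can be reproduced outside the graph.
A small sketch of the same decay (staircase=False) with the timestep scaling,
assuming illustrative hyperparameter values rather than the original
configuration:

import numpy as np

def knob_prob(step, knob_base, knob_decay, steps_per_knob_decay, offset,
              timespan, use_timescale=True):
    """Probability of feeding groundtruth back, one value per timestep."""
    step = max(0.0, step - offset)
    base = knob_base * knob_decay ** (step / steps_per_knob_decay)
    # Later timesteps mix in groundtruth more aggressively.
    scale = 1.0 + np.log(1.0 + np.arange(timespan) * 3.0 * float(use_timescale))
    return np.minimum(1.0, base * scale)

print(knob_prob(step=5000.0, knob_base=1.0, knob_decay=0.5,
                steps_per_knob_decay=2000.0, offset=0.0, timespan=4))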
Example #11
def get_model(opt, device='/cpu:0'):
    """The box model"""
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    box_loss_fn = opt['box_loss_fn']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    pretrain_cnn = opt['pretrain_cnn']
    squash_ctrl_params = opt['squash_ctrl_params']
    clip_gradient = opt['clip_gradient']
    fixed_order = opt['fixed_order']
    ctrl_rnn_inp_struct = opt['ctrl_rnn_inp_struct']  # dense or attn
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth],
                           name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width],
                              name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        phase_train_f = tf.to_float(phase_train)
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0, name='global_step')

        ###############################
        # Random input transformation
        ###############################
        x, y_gt = img.random_transformation(x,
                                            y_gt,
                                            padding,
                                            phase_train,
                                            rnd_hflip=rnd_hflip,
                                            rnd_vflip=rnd_vflip,
                                            rnd_transpose=rnd_transpose,
                                            rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        ############################
        # Canvas: external memory
        ############################
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        ccnn_inp_depth = inp_depth + 1
        acnn_inp_depth = inp_depth + 1

        ############################
        # Controller CNN definition
        ############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers

        if pretrain_cnn:
            log.info('Loading pretrained weights from {}'.format(pretrain_cnn))
            h5f = h5py.File(pretrain_cnn, 'r')
            acnn_nlayers = 0
            # Assuming acnn_nlayers is smaller than ccnn_nlayers.
            for ii in xrange(ccnn_nlayers):
                if 'attn_cnn_w_{}'.format(ii) in h5f:
                    log.info('Loading attn_cnn_w_{}'.format(ii))
                    log.info('Loading attn_cnn_b_{}'.format(ii))
                    acnn_nlayers += 1
            ccnn_init_w = [{
                'w': h5f['attn_cnn_w_{}'.format(ii)][:],
                'b': h5f['attn_cnn_b_{}'.format(ii)][:]
            } for ii in xrange(acnn_nlayers)]
            for ii in xrange(acnn_nlayers):
                for tt in xrange(timespan):
                    for w in ['beta', 'gamma']:
                        ccnn_init_w[ii]['{}_{}'.format(
                            w,
                            tt)] = h5f['attn_cnn_{}_{}_{}'.format(ii, tt,
                                                                  w)][:]
            ccnn_frozen = [True] * acnn_nlayers
            for ii in xrange(acnn_nlayers, ccnn_nlayers):
                ccnn_init_w.append(None)
                ccnn_frozen.append(False)
        else:
            ccnn_init_w = None
            ccnn_frozen = None

        ccnn = nn.cnn(ccnn_filters,
                      ccnn_channels,
                      ccnn_pool,
                      ccnn_act,
                      ccnn_use_bn,
                      phase_train=phase_train,
                      wd=wd,
                      scope='ctrl_cnn',
                      model=model,
                      init_weights=ccnn_init_w,
                      frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

        ############################
        # Controller RNN definition
        ############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        canvas_dim = inp_height * inp_width / (ccnn_subsample**2)

        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        if ctrl_rnn_inp_struct == 'dense':
            crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        elif ctrl_rnn_inp_struct == 'attn':
            crnn_inp_dim = glimpse_feat_dim
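        # 'dense' flattens the entire H' x W' x C feature map into a single
        # RNN input; 'attn' feeds one C-dimensional glimpse per inner
        # iteration of the glimpse RNN below.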

        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_cell = nn.lstm(crnn_inp_dim,
                            crnn_dim,
                            wd=wd,
                            scope='ctrl_lstm',
                            model=model)
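        # Each LSTM state vector concatenates (c, h), hence the 2 * crnn_dim
        # width above; the hidden half is recovered later with
        # tf.slice(state, [0, crnn_dim], [-1, crnn_dim]).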

        ############################
        # Glimpse MLP definition
        ############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None
        gmlp = nn.mlp(gmlp_dims,
                      gmlp_act,
                      add_bias=True,
                      dropout_keep=gmlp_dropout,
                      phase_train=phase_train,
                      wd=wd,
                      scope='glimpse_mlp',
                      model=model)

        ############################
        # Controller MLP definition
        ############################
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        cmlp = nn.mlp(cmlp_dims,
                      cmlp_act,
                      add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train,
                      wd=wd,
                      scope='ctrl_mlp',
                      model=model)
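        # The controller MLP emits 9 numbers per step; this variant only
        # consumes dims 0-1 (normalized centre), 2-3 (log size) and 7
        # (box log gamma) below.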

        ##########################
        # Score MLP definition
        ##########################
        smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid], wd=wd, scope='score_mlp')
        s_out = [None] * timespan

        ##########################
        # Attention box
        ##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_box_gamma = [None] * timespan
        const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
        # A -5.0 bias keeps sigmoid(attn_box) near zero outside the box.
        attn_box_beta = tf.constant([-5.0])
        iou_soft_box = [None] * timespan

        #############################
        # Groundtruth attention box
        #############################
        attn_top_left_gt, attn_bot_right_gt, attn_box_gt = base.get_gt_box(
            y_gt, padding_ratio=attn_box_padding_ratio, center_shift_ratio=0.0)
        attn_ctr_gt, attn_size_gt = base.get_box_ctr_size(
            attn_top_left_gt, attn_bot_right_gt)
        attn_ctr_norm_gt = base.get_normalized_center(attn_ctr_gt, inp_height,
                                                      inp_width)
        attn_lg_size_gt = base.get_normalized_size(attn_size_gt, inp_height,
                                                   inp_width)

        ##########################
        # Groundtruth mix
        ##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

        ##########################
        # Computation graph
        ##########################
        for tt in xrange(timespan):
            # Controller CNN
            ccnn_inp = tf.concat(3, [x, canvas])
            acnn_inp = ccnn_inp
            h_ccnn[tt] = ccnn(ccnn_inp)
            _h_ccnn = h_ccnn[tt]
            h_ccnn_last = _h_ccnn[-1]

            # Controller RNN [B, R1]
            if ctrl_rnn_inp_struct == 'dense':
                crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
                crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[tt] = \
                    crnn_cell(crnn_inp, crnn_state[tt - 1])
                h_crnn[tt] = tf.slice(crnn_state[tt], [0, crnn_dim],
                                      [-1, crnn_dim])

                ctrl_out = cmlp(h_crnn[tt])[-1]

            elif ctrl_rnn_inp_struct == 'attn':
                crnn_inp = tf.reshape(h_ccnn_last,
                                      [-1, glimpse_map_dim, glimpse_feat_dim])
                crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
                crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
                h_crnn[tt] = [None] * num_ctrl_rnn_iter

                crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))

                crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
                crnn_glimpse_map[tt][0] = tf.ones(
                    tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

                # Inner glimpse RNN
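                # Each inner iteration pools the CNN feature map with the
                # current glimpse weights, runs the controller LSTM on the
                # pooled vector, and asks the glimpse MLP (softmax output)
                # to re-weight the H' * W' locations for the next pass.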
                for tt2 in xrange(num_ctrl_rnn_iter):
                    crnn_glimpse = tf.reduce_sum(
                        crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                    crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
                        crnn_g_o[tt][tt2] = \
                        crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
                    h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2],
                                               [0, crnn_dim], [-1, crnn_dim])
                    h_gmlp = gmlp(h_crnn[tt][tt2])
                    if tt2 < num_ctrl_rnn_iter - 1:
                        crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(
                            h_gmlp[-1], 2)

                ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            attn_box_gamma[tt] = tf.reshape(tf.exp(attn_box_lg_gamma[tt]),
                                            [-1, 1, 1, 1])
            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            # Initial filters (predicted)
            filter_y = base.get_gaussian_filter(attn_ctr[tt][:, 0],
                                                attn_size[tt][:, 0],
                                                attn_lg_var[tt][:, 0],
                                                inp_height, filter_height)
            filter_x = base.get_gaussian_filter(attn_ctr[tt][:, 1],
                                                attn_size[tt][:, 1],
                                                attn_lg_var[tt][:, 1],
                                                inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            attn_box[tt] = base.extract_patch(const_ones * attn_box_gamma[tt],
                                              filter_y_inv, filter_x_inv, 1)
            attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
            attn_box[tt] = tf.reshape(attn_box[tt],
                                      [-1, 1, inp_height, inp_width])

            if fixed_order:
                _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
            else:
                iou_soft_box[tt] = base.f_inter(
                    attn_box[tt], attn_box_gt) / \
                    base.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
                grd_match = base.f_greedy_match(iou_soft_box[tt],
                                                grd_match_cum)
                grd_match = tf.expand_dims(tf.expand_dims(grd_match, 2), 3)
                _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1), 3)

            # Add independent uniform noise to groundtruth.
            _noise = tf.random_uniform(
                tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3)
            _y_out = _y_out - _y_out * _noise
            # Paint the (noisy) groundtruth onto the canvas; stop_gradient
            # keeps the model from backpropagating through its own history.
            canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))
            # canvas += tf.stop_gradient(_y_out)

            # Scoring network
            if ctrl_rnn_inp_struct == 'dense':
                s_out[tt] = smlp(h_crnn[tt])[-1]
            elif ctrl_rnn_inp_struct == 'attn':
                s_out[tt] = smlp(h_crnn[tt][-1])[-1]

        #########################
        # Model outputs
        #########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        attn_top_left = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left])
        attn_bot_right = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_ctr_norm_gt'] = attn_ctr_norm_gt
        model['attn_lg_size_gt'] = attn_lg_size_gt
        model['attn_top_left_gt'] = attn_top_left_gt
        model['attn_bot_right_gt'] = attn_bot_right_gt
        model['attn_box_gt'] = attn_box_gt
        attn_ctr_norm = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr_norm])
        attn_lg_size = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_size])
        model['attn_ctr_norm'] = attn_ctr_norm
        model['attn_lg_size'] = attn_lg_size

        attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
        attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

        #########################
        # Loss function
        #########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

        ############################
        # Box loss
        ############################
        if fixed_order:
            # [B, T] for fixed order.
            iou_soft_box = base.f_iou(attn_box, attn_box_gt, pairwise=False)
        else:
            # [B, T, T] for matching.
            iou_soft_box = tf.concat(1, [
                tf.expand_dims(iou_soft_box[tt], 1) for tt in xrange(timespan)
            ])

        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        if fixed_order:
            match_box = identity_match
        else:
            match_box = base.f_segm_match(iou_soft_box, s_gt)

        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
        match_count_box = tf.maximum(1.0, match_count_box)

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_box_mask = iou_soft_box
        else:
            iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box / match_count_box) / num_ex_f

        if box_loss_fn == 'mse':
            box_loss = base.f_match_loss(attn_params,
                                         attn_params_gt,
                                         match_box,
                                         timespan,
                                         base.f_squared_err,
                                         model=model)
        elif box_loss_fn == 'huber':
            box_loss = base.f_match_loss(attn_params, attn_params_gt,
                                         match_box, timespan, base.f_huber)
        elif box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_iou':
            # NOTE: wt_iou_soft_box, box_map and box_map_gt in the branches
            # below are not defined in this snippet and are assumed to be
            # computed elsewhere.
            box_loss = -wt_iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -base.f_weighted_coverage(iou_soft_box, box_map_gt)
        elif box_loss_fn == 'bce':
            box_loss = base.f_match_loss(box_map, box_map_gt, match_box,
                                         timespan, base.f_bce)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        model['box_loss_coeff'] = box_loss_coeff
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

        ####################
        # Score loss
        ####################
        conf_loss = base.f_conf_loss(s_out,
                                     match_box,
                                     timespan,
                                     use_cum_min=True)
        model['conf_loss'] = conf_loss
        conf_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', conf_loss_coeff * conf_loss)

        ####################
        # Total loss
        ####################
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(base_learn_rate,
                                                global_step,
                                                steps_per_learn_rate_decay,
                                                learn_rate_decay,
                                                staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7
        train_step = GradientClipOptimizer(tf.train.AdamOptimizer(learn_rate,
                                                                  epsilon=eps),
                                           clip=clip_gradient).minimize(
                                               total_loss,
                                               global_step=global_step)
        model['train_step'] = train_step

        ####################
        # Glimpse
        ####################
        # T * T2 * [B, H' * W'] => [B, T, T2, H', W']
        if ctrl_rnn_inp_struct == 'attn':
            crnn_glimpse_map = tf.concat(1, [
                tf.expand_dims(
                    tf.concat(1, [
                        tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
                        for tt2 in xrange(num_ctrl_rnn_iter)
                    ]), 1) for tt in xrange(timespan)
            ])
            crnn_glimpse_map = tf.reshape(
                crnn_glimpse_map,
                [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w])
            model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

        return model
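
# Usage sketch (illustrative, not from the original source): one training
# step for the model above, given the dict it returns and numpy batches
# shaped [B, H, W, D], [B, T, H, W] and [B, T].
def _train_one_step_example(m, x_batch, y_batch, s_batch):
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    _, loss = sess.run([m['train_step'], m['loss']],
                       feed_dict={m['x']: x_batch,
                                  m['y_gt']: y_batch,
                                  m['s_gt']: s_batch,
                                  m['phase_train']: True})
    return loss
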
def get_model(opt, device='/cpu:0'):
    model = {}
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    cnn_filter_size = opt['cnn_filter_size']
    cnn_depth = opt['cnn_depth']
    cnn_pool = opt['cnn_pool']
    mlp_dims = opt['mlp_dims']
    mlp_dropout = opt['mlp_dropout']
    wd = opt['weight_decay']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']

############################
# Input definition
############################
    with tf.device(get_device_fn(device)):
        x1 = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth], name='x1')
        x2 = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth], name='x2')
        phase_train = tf.placeholder('bool', name='phase_train')
        y_gt = tf.placeholder('float', [None], name='y_gt')
        global_step = tf.Variable(0.0)

############################
# Feature CNN definition
############################
        cnn_channels = [inp_depth] + cnn_depth
        cnn_nlayers = len(cnn_filter_size)
        cnn_use_bn = [True] * cnn_nlayers
        cnn_act = [tf.nn.relu] * cnn_nlayers
        cnn = nn.cnn(cnn_filter_size, cnn_channels, cnn_pool, cnn_act,
                     cnn_use_bn, phase_train=phase_train, wd=wd, scope='cnn')

        subsample = np.array(cnn_pool).prod()
        cnn_h = inp_height / subsample
        cnn_w = inp_width / subsample
        feat_dim = cnn_h * cnn_w * cnn_channels[-1]

############################
# Matching MLP definition
############################
        mlp_nlayers = len(mlp_dims)
        mlp_dims = [2 * feat_dim] + mlp_dims
        mlp_dropout_keep = [1 - mlp_dropout] * mlp_nlayers
        mlp_act = [tf.nn.relu] * (mlp_nlayers - 1) + [tf.sigmoid]
        mlp = nn.mlp(mlp_dims, mlp_act,
                     dropout_keep=mlp_dropout_keep, phase_train=phase_train)

############################
# Computation graph
############################
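        # Siamese design: the same cnn object (and therefore the same
        # weights) is applied to both x1 and x2, and the two flattened
        # feature vectors are concatenated for the matching MLP.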
        f1 = cnn(x1)
        f1 = tf.reshape(f1[-1], [-1, feat_dim])
        f2 = cnn(x2)
        f2 = tf.reshape(f2[-1], [-1, feat_dim])
        f_join = tf.concat(1, [f1, f2])
        y_out = mlp(f_join)[-1]
        y_out = tf.reshape(y_out, [-1])

############################
# Loss function
############################
        num_ex = tf.shape(y_gt)[0]
        num_ex_f = tf.to_float(num_ex)
        bce = f_bce(y_out, y_gt)
        bce = tf.reduce_sum(bce) / num_ex_f
        tf.add_to_collection('losses', bce)
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')

############################
# Statistics
############################
        y_out_thresh = tf.to_float(y_out > 0.5)
        acc = tf.reduce_sum(
            tf.to_float(tf.equal(y_out_thresh, y_gt))) / num_ex_f

####################
# Optimizer
####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        eps = 1e-7
        train_step = tf.train.AdamOptimizer(learn_rate, epsilon=eps).minimize(
            total_loss, global_step=global_step)

############################
# Computation nodes
############################
        model['x1'] = x1
        model['x2'] = x2
        model['y_gt'] = y_gt
        model['phase_train'] = phase_train
        model['y_out'] = y_out
        model['loss'] = total_loss
        model['acc'] = acc
        model['learn_rate'] = learn_rate
        model['train_step'] = train_step

    return model
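
# Usage sketch (illustrative, not from the original source): pair-matching
# accuracy for the Siamese model above, given its returned dict and numpy
# batches x1/x2 of shape [B, H, W, D] and binary labels of shape [B].
def _eval_matching_example(m, x1_batch, x2_batch, labels):
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    return sess.run(m['acc'], feed_dict={m['x1']: x1_batch,
                                         m['x2']: x2_batch,
                                         m['y_gt']: labels,
                                         m['phase_train']: False})
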
def get_model(opt, device='/cpu:0'):
    """The original model"""
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']
    attn_rnn_hid_dim = opt['attn_rnn_hid_dim']

    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    use_gt_attn = opt['use_gt_attn']
    segm_loss_fn = opt['segm_loss_fn']
    box_loss_fn = opt['box_loss_fn']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    use_attn_rnn = opt['use_attn_rnn']
    use_knob = opt['use_knob']
    knob_base = opt['knob_base']
    knob_decay = opt['knob_decay']
    steps_per_knob_decay = opt['steps_per_knob_decay']
    use_canvas = opt['use_canvas']
    knob_box_offset = opt['knob_box_offset']
    knob_segm_offset = opt['knob_segm_offset']
    knob_use_timescale = opt['knob_use_timescale']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']
    squash_ctrl_params = opt['squash_ctrl_params']
    use_iou_box = opt['use_iou_box']
    fixed_order = opt['fixed_order']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

############################
# Input definition
############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth])
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width])

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan])

        # Whether in training stage.
        phase_train = tf.placeholder('bool')
        phase_train_f = tf.to_float(phase_train)
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0)

        # Random image transformation
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

############################
# Canvas: external memory
############################
        if use_canvas:
            canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
            ccnn_inp_depth = inp_depth + 1
            # ccnn_inp_depth = inp_depth
            acnn_inp_depth = inp_depth + 1
        else:
            ccnn_inp_depth = inp_depth
            acnn_inp_depth = inp_depth

############################
# Controller CNN definition
############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers
        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model)
        h_ccnn = [None] * timespan

############################
# Controller RNN definition
############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        crnn_state = [None] * (timespan + 1)
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm',
                            model=model)

############################
# Controller MLP definition
############################
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='ctrl_mlp',
                      model=model)

##########################
# Attention CNN definition
##########################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = [2] * acnn_nlayers
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers
        acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act,
                      acnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='attn_cnn', model=model)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

##########################
# Attention RNN definition
##########################
        acnn_subsample = np.array(acnn_pool).prod()
        arnn_h = filter_height / acnn_subsample
        arnn_w = filter_width / acnn_subsample

        if use_attn_rnn:
            arnn_dim = attn_rnn_hid_dim
            arnn_inp_dim = arnn_h * arnn_w * acnn_channels[-1]
            arnn_state = [None] * (timespan + 1)
            arnn_g_i = [None] * timespan
            arnn_g_f = [None] * timespan
            arnn_g_o = [None] * timespan
            arnn_state[-1] = tf.zeros(tf.pack([num_ex, arnn_dim * 2]))
            arnn_cell = nn.lstm(arnn_inp_dim, arnn_dim,
                                wd=wd, scope='attn_lstm')
            amlp_inp_dim = arnn_dim
        else:
            amlp_inp_dim = arnn_h * arnn_w * acnn_channels[-1]

##########################
# Attention MLP definition
##########################
        core_depth = attn_mlp_depth
        core_dim = arnn_h * arnn_w * core_depth
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None
        # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers
        amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout,
                      phase_train=phase_train, wd=wd, scope='attn_mlp',
                      model=model)

##########################
# Score MLP definition
##########################
        smlp = nn.mlp([crnn_dim + core_dim, 1], [tf.sigmoid], wd=wd,
                      scope='score_mlp', model=model)
        s_out = [None] * timespan

#############################
# Attention DCNN definition
#############################
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = [2] * (adcnn_nlayers - 1) + [1]
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth
        adcnn_use_bn = [use_bn] * adcnn_nlayers
        adcnn_skip_ch = [0] + acnn_channels[::-1][1:] + [ccnn_inp_depth]
        adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool,
                        adcnn_act, use_bn=adcnn_use_bn, skip_ch=adcnn_skip_ch,
                        phase_train=phase_train, wd=wd, model=model)
        h_adcnn = [None] * timespan

##########################
# Attention box
##########################
        attn_box = [None] * timespan
        iou_soft_box = [None] * timespan
        const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = -5.0


#############################
# Groundtruth attention box
#############################
        # [B, T, 2]
        attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \
            attn_top_left_gt, attn_bot_right_gt = \
            base.get_gt_attn(y_gt,
                             padding_ratio=attn_box_padding_ratio,
                             center_shift_ratio=0.0)
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise))
        attn_delta_gt = _get_delta_from_size(
            attn_size_gt, filter_height, filter_width)
        attn_delta_gt_noise = _get_delta_from_size(
            attn_size_gt_noise, filter_height, filter_width)

        attn_lg_gamma_gt = tf.zeros(tf.pack([num_ex, timespan, 1]))
        attn_box_lg_gamma_gt = tf.zeros(tf.pack([num_ex, timespan, 1]))
        y_out_lg_gamma_gt = tf.zeros(tf.pack([num_ex, timespan, 1]))
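        # Groundtruth log gammas are fixed at zero, i.e. unit gain after exp.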
        gtbox_top_left = [None] * timespan
        gtbox_bot_right = [None] * timespan

        attn_ctr = [None] * timespan
        attn_ctr_norm = [None] * timespan
        attn_delta = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_lg_gamma = [None] * timespan
        attn_gamma = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_box_gamma = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan

##########################
# Groundtruth mix
##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

        # Scale mix ratio on different timesteps.
        if knob_use_timescale:
            gt_knob_time_scale = tf.reshape(
                1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0),
                [1, timespan, 1])
        else:
            gt_knob_time_scale = tf.ones([1, timespan, 1])

        # Mix in groundtruth box.
        global_step_box = tf.maximum(0.0, global_step - knob_box_offset)
        gt_knob_prob_box = tf.train.exponential_decay(
            knob_base, global_step_box, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_box = tf.minimum(
            1.0, gt_knob_prob_box * gt_knob_time_scale)
        gt_knob_box = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box)
        model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0]

        # Mix in groundtruth segmentation.
        global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset)
        gt_knob_prob_segm = tf.train.exponential_decay(
            knob_base, global_step_segm, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_segm = tf.minimum(
            1.0, gt_knob_prob_segm * gt_knob_time_scale)
        gt_knob_segm = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm)
        model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0]
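        # Illustration (hypothetical numbers): with knob_base = 1.0,
        # knob_decay = 0.5 and steps_per_knob_decay = 700, the probability of
        # substituting groundtruth at global step 1400 (past the offset) is
        # 1.0 * 0.5 ** (1400 / 700.0) = 0.25, scaled per timestep by
        # gt_knob_time_scale and capped at 1.0.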

##########################
# Segmentation output
##########################
        y_out = [None] * timespan
        y_out_lg_gamma = [None] * timespan
        y_out_beta = -5.0

##########################
# Computation graph
##########################
        for tt in xrange(timespan):
            # Controller CNN
            if use_canvas:
                ccnn_inp = tf.concat(3, [x, canvas])
                acnn_inp = ccnn_inp
            else:
                ccnn_inp = x
                acnn_inp = x
            h_ccnn[tt] = ccnn(ccnn_inp)
            h_ccnn_last = h_ccnn[tt][-1]
            crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])

            # Controller RNN [B, R1]
            crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[tt] = \
                crnn_cell(crnn_inp, crnn_state[tt - 1])
            h_crnn[tt] = tf.slice(
                crnn_state[tt], [0, crnn_dim], [-1, crnn_dim])

            if use_gt_attn:
                attn_ctr[tt] = attn_ctr_gt[:, tt, :]
                attn_delta[tt] = attn_delta_gt[:, tt, :]
                attn_lg_var[tt] = attn_lg_var_gt[:, tt, :]
                attn_lg_gamma[tt] = attn_lg_gamma_gt[:, tt, :]
                attn_box_lg_gamma[tt] = attn_box_lg_gamma_gt[:, tt, :]
                y_out_lg_gamma[tt] = y_out_lg_gamma_gt[:, tt, :]
            else:
                ctrl_out = cmlp(h_crnn[tt])[-1]
                attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
                attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])
                if squash_ctrl_params:
                    # Restrict to (-1, 1)
                    attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                    # Restrict to (-inf, 0)
                    attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

                attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                    attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
                attn_delta[tt] = _get_delta_from_size(
                    attn_size[tt], filter_height, filter_width)
                attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
                attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1])
                attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
                y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1])

            attn_gamma[tt] = tf.reshape(
                tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1])
            attn_box_gamma[tt] = tf.reshape(
                tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
            y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1])

            # Initial filters (predicted)
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            if use_iou_box:
                _idx_map = base.get_idx_map(
                    tf.pack([num_ex, inp_height, inp_width]))
                attn_top_left[tt], attn_bot_right[tt] = _get_attn_coord(
                    attn_ctr[tt], attn_delta[tt], filter_height, filter_width)
                attn_box[tt] = base.get_filled_box_idx(
                    _idx_map, attn_top_left[tt], attn_bot_right[tt])
                attn_box[tt] = tf.reshape(attn_box[tt],
                                          [-1, 1, inp_height, inp_width])
            else:
                attn_box[tt] = base.extract_patch(const_ones * attn_box_gamma[tt],
                                                  filter_y_inv, filter_x_inv, 1)
                attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
                attn_box[tt] = tf.reshape(attn_box[tt],
                                          [-1, 1, inp_height, inp_width])

            # Kick in GT bbox.
            if use_knob:
                # IOU [B, 1, T]
                if use_iou_box:
                    _top_left = tf.expand_dims(attn_top_left[tt], 1)
                    _bot_right = tf.expand_dims(attn_bot_right[tt], 1)

                    if not fixed_order:
                        iou_soft_box[tt] = base.f_iou_box(
                            _top_left, _bot_right, attn_top_left_gt,
                            attn_bot_right_gt)
                else:
                    if not fixed_order:
                        iou_soft_box[tt] = base.f_inter(
                            attn_box[tt], attn_box_gt) / \
                            base.f_union(attn_box[tt], attn_box_gt, eps=1e-5)

                if fixed_order:
                    attn_ctr_gtm = attn_ctr_gt_noise[:, tt, :]
                    attn_delta_gtm = attn_delta_gt_noise[:, tt, :]
                    attn_size_gtm = attn_size_gt_noise[:, tt, :]
                else:
                    grd_match = base.f_greedy_match(
                        iou_soft_box[tt], grd_match_cum)
                    # Let's try not using cumulative match.
                    # grd_match_cum += grd_match

                    # [B, T, 1]
                    grd_match = tf.expand_dims(grd_match, 2)
                    attn_ctr_gtm = tf.reduce_sum(
                        grd_match * attn_ctr_gt_noise, 1)
                    attn_delta_gtm = tf.reduce_sum(
                        grd_match * attn_delta_gt_noise, 1)
                    attn_size_gtm = tf.reduce_sum(
                        grd_match * attn_size_gt_noise, 1)

                _gt_knob_box = gt_knob_box
                attn_ctr[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_ctr_gtm + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_ctr[tt]
                attn_delta[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_delta_gtm + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_delta[tt]
                attn_size[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_size_gtm + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_size[tt]

            attn_top_left[tt], attn_bot_right[tt] = _get_attn_coord(
                attn_ctr[tt], attn_delta[tt], filter_height, filter_width)

            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attended patch [B, A, A, D]
            x_patch[tt] = attn_gamma[tt] * base.extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]

            if use_attn_rnn:
                # RNN [B, T, R2]
                arnn_inp = tf.reshape(h_acnn_last[tt], [-1, arnn_inp_dim])
                arnn_state[tt], arnn_g_i[tt], arnn_g_f[tt], arnn_g_o[tt] = \
                    arnn_cell(arnn_inp, arnn_state[tt - 1])

            # Dense segmentation network [B, R] => [B, M]
            if use_attn_rnn:
                h_arnn = tf.slice(
                    arnn_state[tt], [0, arnn_dim], [-1, arnn_dim])
                amlp_inp = h_arnn
            else:
                amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core_img = tf.reshape(
                h_core, [-1, arnn_h, arnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core_img, skip=skip)

            # Output
            y_out[tt] = base.extract_patch(
                h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width])

            # Scoring network
            smlp_inp = tf.concat(1, [h_crnn[tt], h_core])
            s_out[tt] = smlp(smlp_inp)[-1]

            # Knob: mix in GT segmentations at this timestep.
            # [B, N, 1, 1]
            if use_canvas:
                if use_knob:
                    _gt_knob_segm = tf.expand_dims(
                        tf.expand_dims(gt_knob_segm[:, tt, 0: 1], 2), 3)

                    if fixed_order:
                        _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
                    else:
                        grd_match = tf.expand_dims(grd_match, 3)
                        _y_out = tf.expand_dims(tf.reduce_sum(
                            grd_match * y_gt, 1), 3)
                    # Add independent uniform noise to groundtruth.
                    _noise = tf.random_uniform(
                        tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3)
                    _y_out = _y_out - _y_out * _noise
                    _y_out = phase_train_f * _gt_knob_segm * _y_out + \
                        (1 - phase_train_f * _gt_knob_segm) * \
                        tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1])
                else:
                    _y_out = tf.reshape(y_out[tt],
                                        [-1, inp_height, inp_width, 1])
                canvas += tf.stop_gradient(_y_out)

#########################
# Model outputs
#########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        x_patch = tf.concat(1, [tf.expand_dims(x_patch[tt], 1)
                                for tt in xrange(timespan)])
        model['x_patch'] = x_patch

        attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_top_left])
        attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1)
                                 for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                  for tmp in attn_size])
        attn_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_lg_gamma])
        attn_box_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                          for tmp in attn_box_lg_gamma])
        y_out_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in y_out_lg_gamma])
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_box_gt'] = attn_box_gt

#########################
# Loss function
#########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

############################
# Box loss
############################
        if fixed_order:
            # [B, T] for fixed order.
            iou_soft_box = base.f_iou(
                attn_box, attn_box_gt, pairwise=False)
        else:
            if use_knob:
                # [B, T, T] for matching.
                iou_soft_box = tf.concat(
                    1, [tf.expand_dims(iou_soft_box[tt], 1)
                        for tt in xrange(timespan)])
            else:
                iou_soft_box = base.f_iou(attn_box, attn_box_gt,
                                          timespan, pairwise=True)

        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        if fixed_order:
            match_box = identity_match
        else:
            match_box = base.f_segm_match(iou_soft_box, s_gt)

        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
        match_count_box = tf.maximum(1.0, match_count_box)

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_box_mask = iou_soft_box
        else:
            iou_soft_box_mask = tf.reduce_sum(
                iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
        iou_soft_box = tf.reduce_sum(
            iou_soft_box / match_count_box) / num_ex_f

        if box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -base.f_weighted_coverage(iou_soft_box, attn_box_gt)
        elif box_loss_fn == 'mse':
            box_loss = base.f_match_loss(
                y_out, y_gt, match_box, timespan, f_mse)
        elif box_loss_fn == 'bce':
            box_loss = base.f_match_loss(
                y_out, y_gt, match_box, timespan, f_bce)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        model['box_loss_coeff'] = box_loss_coeff
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

##############################
# Segmentation loss
##############################
        # IoU (soft)
        iou_soft_pairwise = base.f_iou(y_out, y_gt, timespan, pairwise=True)
        real_match = base.f_segm_match(iou_soft_pairwise, s_gt)
        if fixed_order:
            iou_soft = base.f_iou(y_out, y_gt, pairwise=False)
            match = identity_match
        else:
            iou_soft = iou_soft_pairwise
            match = real_match
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = base.f_weighted_coverage(iou_soft_pairwise, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = base.f_unweighted_coverage(
            iou_soft_pairwise, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_mask = iou_soft
        else:
            iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(iou_soft_mask, [1])
        iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f
        model['iou_soft'] = iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = 1.0
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

####################
# Score loss
####################
        conf_loss = base.f_conf_loss(s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        tf.add_to_collection('losses', loss_mix_ratio * conf_loss)

####################
# Total loss
####################
        total_loss = tf.add_n(tf.get_collection(
            'losses'), name='total_loss')
        model['loss'] = total_loss

####################
# Optimizer
####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=1.0).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

####################
# Statistics
####################
        # Hard statistics below always use matching, regardless of fixed_order.
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = base.f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = base.f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = base.f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_hard'] = iou_hard

        dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(
            dice * real_match, reduction_indices=[1, 2]) / match_count) / num_ex_f
        model['dice'] = dice

        model['count_acc'] = _count_acc(s_out, s_gt)
        model['dic'] = _dic(s_out, s_gt, abs=False)
        model['dic_abs'] = _dic(s_out, s_gt, abs=True)

################################
# Controller output statistics
################################
        attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / num_ex_f / timespan
        attn_box_lg_gamma_mean = tf.reduce_sum(
            attn_box_lg_gamma) / num_ex_f / timespan
        y_out_lg_gamma_mean = tf.reduce_sum(
            y_out_lg_gamma) / num_ex_f / timespan
        model['attn_lg_gamma_mean'] = attn_lg_gamma_mean
        model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean
        model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean

##################################
# Controller RNN gate statistics
##################################
        crnn_g_i = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in crnn_g_i])
        crnn_g_f = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in crnn_g_f])
        crnn_g_o = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in crnn_g_o])
        crnn_g_i_avg = tf.reduce_sum(
            crnn_g_i) / num_ex_f / timespan / ctrl_rnn_hid_dim
        crnn_g_f_avg = tf.reduce_sum(
            crnn_g_f) / num_ex_f / timespan / ctrl_rnn_hid_dim
        crnn_g_o_avg = tf.reduce_sum(
            crnn_g_o) / num_ex_f / timespan / ctrl_rnn_hid_dim
        model['crnn_g_i_avg'] = crnn_g_i_avg
        model['crnn_g_f_avg'] = crnn_g_f_avg
        model['crnn_g_o_avg'] = crnn_g_o_avg

    return model

def get_model(opt, device='/cpu:0'):
    """CNN -> RNN -> DCNN -> Instances"""
    model = {}
    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']

    rnn_type = opt['rnn_type']
    cnn_filter_size = opt['cnn_filter_size']
    cnn_depth = opt['cnn_depth']
    dcnn_filter_size = opt['dcnn_filter_size']
    dcnn_depth = opt['dcnn_depth']
    conv_lstm_filter_size = opt['conv_lstm_filter_size']
    conv_lstm_hid_depth = opt['conv_lstm_hid_depth']
    rnn_hid_dim = opt['rnn_hid_dim']
    mlp_depth = opt['mlp_depth']
    wd = opt['weight_decay']
    segm_dense_conn = opt['segm_dense_conn']
    use_bn = opt['use_bn']
    use_deconv = opt['use_deconv']
    add_skip_conn = opt['add_skip_conn']
    score_use_core = opt['score_use_core']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    num_mlp_layers = opt['num_mlp_layers']
    mlp_dropout_ratio = opt['mlp_dropout']
    segm_loss_fn = opt['segm_loss_fn']
    clip_gradient = opt['clip_gradient']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth])

        # Whether in training stage, required for batch norm.
        phase_train = tf.placeholder('bool')

        # Groundtruth segmentation maps, [B, T, H, W]
        y_gt = tf.placeholder(
            'float', [None, timespan, inp_height, inp_width])

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan])

        model['x'] = x
        model['phase_train'] = phase_train
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt

        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Random image transformation
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        # CNN
        cnn_filters = cnn_filter_size
        cnn_nlayers = len(cnn_filters)
        cnn_channels = [inp_depth] + cnn_depth
        cnn_pool = [2] * cnn_nlayers
        cnn_act = [tf.nn.relu] * cnn_nlayers
        cnn_use_bn = [use_bn] * cnn_nlayers
        cnn = nn.cnn(cnn_filters, cnn_channels, cnn_pool, cnn_act, cnn_use_bn,
                     phase_train=phase_train, wd=wd, model=model)
        h_cnn = cnn(x)
        h_pool3 = h_cnn[-1]

        # RNN input size
        subsample = np.array(cnn_pool).prod()
        rnn_h = inp_height / subsample
        rnn_w = inp_width / subsample

        # Low-res segmentation depth
        core_depth = mlp_depth if segm_dense_conn else 1
        core_dim = rnn_h * rnn_w * core_depth
        rnn_state = [None] * (timespan + 1)

        # RNN
        if rnn_type == 'conv_lstm':
            rnn_depth = conv_lstm_hid_depth
            rnn_dim = rnn_h * rnn_w * rnn_depth
            conv_lstm_inp_depth = cnn_channels[-1]
            rnn_inp = h_pool3
            rnn_state[-1] = tf.zeros(tf.pack([num_ex,
                                              rnn_h, rnn_w, rnn_depth * 2]))
            rnn_cell = nn.conv_lstm(conv_lstm_inp_depth, rnn_depth,
                                    conv_lstm_filter_size, wd=wd)
        elif rnn_type == 'lstm' or rnn_type == 'gru':
            rnn_dim = rnn_hid_dim
            rnn_inp_dim = rnn_h * rnn_w * cnn_channels[-1]
            rnn_inp = tf.reshape(
                h_pool3, [-1, rnn_h * rnn_w * cnn_channels[-1]])
            if rnn_type == 'lstm':
                rnn_state[-1] = tf.zeros(tf.pack([num_ex, rnn_hid_dim * 2]))
                rnn_cell = nn.lstm(rnn_inp_dim, rnn_hid_dim, wd=wd)
            else:
                rnn_state[-1] = tf.zeros(tf.pack([num_ex, rnn_hid_dim]))
                rnn_cell = nn.gru(rnn_inp_dim, rnn_hid_dim, wd=wd)
        else:
            raise Exception('Unknown RNN type: {}'.format(rnn_type))

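        # Unroll the recurrence. The same CNN feature map is fed at every
        # step; the gate activations (_gi, _gf, _go) are discarded here.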
        for tt in xrange(timespan):
            rnn_state[tt], _gi, _gf, _go = rnn_cell(rnn_inp, rnn_state[tt - 1])

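        # Recover the hidden activation from each state: (conv-)LSTM states
        # pack [cell, hidden], so take the second half; a GRU state is the
        # hidden activation itself.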
        if rnn_type == 'conv_lstm':
            h_rnn = [tf.slice(rnn_state[tt], [0, 0, 0, rnn_depth],
                              [-1, -1, -1, rnn_depth])
                     for tt in xrange(timespan)]
        elif rnn_type == 'lstm':
            h_rnn = [tf.slice(rnn_state[tt], [0, rnn_dim], [-1, rnn_dim])
                     for tt in xrange(timespan)]
        elif rnn_type == 'gru':
            h_rnn = [rnn_state[tt] for tt in xrange(timespan)]

        h_rnn_all = tf.concat(
            1, [tf.expand_dims(h_rnn[tt], 1) for tt in xrange(timespan)])

        # Core segmentation network.
        if segm_dense_conn:
            # Dense segmentation network
            h_rnn_all = tf.reshape(h_rnn_all, [-1, rnn_dim])
            mlp_dims = [rnn_dim] + [core_dim] * num_mlp_layers
            mlp_act = [tf.nn.relu] * num_mlp_layers
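            # nn.mlp appears to take keep probabilities, so convert the
            # dropout ratio accordingly.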
            mlp_dropout = [1.0 - mlp_dropout_ratio] * num_mlp_layers
            segm_mlp = nn.mlp(mlp_dims, mlp_act, mlp_dropout,
                              phase_train=phase_train, wd=wd)
            h_core = segm_mlp(h_rnn_all)[-1]
            h_core = tf.reshape(h_core, [-1, rnn_h, rnn_w, mlp_depth])
        else:
            # Convolutional segmentation network
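            # A single 3x3 convolution produces per-pixel logits; the softmax
            # is taken jointly over all rnn_h * rnn_w locations, and the
            # learned scalar b_log_softmax shifts the log-probabilities.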
            w_segm_conv = nn.weight_variable([3, 3, rnn_depth, 1], wd=wd)
            b_segm_conv = nn.weight_variable([1], wd=wd)
            b_log_softmax = nn.weight_variable([1])
            h_rnn_all = tf.reshape(
                h_rnn_all, [-1, rnn_h, rnn_w, rnn_depth])
            h_core = nn.conv2d(h_rnn_all, w_segm_conv) + b_segm_conv
            h_core = tf.reshape(h_core, [-1, rnn_h * rnn_w])
            h_core = tf.log(tf.nn.softmax(h_core)) + b_log_softmax
            h_core = tf.reshape(h_core, [-1, rnn_h, rnn_w, 1])

        # Deconv net to upsample
        if use_deconv:
            dcnn_filters = dcnn_filter_size
            dcnn_nlayers = len(dcnn_filters)
            dcnn_unpool = [2] * (dcnn_nlayers - 1) + [1]
            dcnn_act = [tf.nn.relu] * (dcnn_nlayers - 1) + [tf.sigmoid]
            if segm_dense_conn:
                dcnn_channels = [mlp_depth] + dcnn_depth
            else:
                dcnn_channels = [1] * (dcnn_nlayers + 1)
            dcnn_use_bn = [use_bn] * dcnn_nlayers

            skip = None
            skip_ch = None
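            # U-Net-style skip connections: CNN feature maps of matching
            # resolution are fed into the deconv stack.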
            if add_skip_conn:
                skip, skip_ch = build_skip_conn(
                    cnn_channels, h_cnn, x, timespan)

            dcnn = nn.dcnn(dcnn_filters, dcnn_channels, dcnn_unpool, dcnn_act,
                           dcnn_use_bn, skip_ch=skip_ch,
                           phase_train=phase_train, wd=wd, model=model)
            h_dcnn = dcnn(h_core, skip=skip)
            y_out = tf.reshape(
                h_dcnn[-1], [-1, timespan, inp_height, inp_width])
        else:
            y_out = tf.reshape(
                tf.image.resize_bilinear(h_core, [inp_height, inp_width]),
                [-1, timespan, inp_height, inp_width])

        model['y_out'] = y_out

        # Scoring network
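        # The score head predicts, per timestep, whether the emitted mask
        # corresponds to a real object, reading either the core segmentation
        # features or the raw RNN state.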
        if score_use_core:
            # Use core network to predict score
            score_inp = h_core
            score_inp_shape = [-1, core_dim]
            score_inp = tf.reshape(score_inp, score_inp_shape)
            score_dim = core_dim
        else:
            # Use RNN hidden state to predict score
            score_inp = h_rnn_all
            if rnn_type == 'conv_lstm':
                score_inp_shape = [-1, rnn_h, rnn_w, rnn_depth]
                score_inp = tf.reshape(score_inp, score_inp_shape)
                score_maxpool = opt['score_maxpool']
                score_dim = rnn_h * rnn_w / (score_maxpool ** 2) * rnn_depth
                if score_maxpool > 1:
                    score_inp = nn.max_pool(score_inp, score_maxpool)
                score_inp = tf.reshape(score_inp, [-1, score_dim])
            else:
                score_inp_shape = [-1, rnn_dim]
                score_inp = tf.reshape(score_inp, score_inp_shape)
                score_dim = rnn_dim

        score_mlp = nn.mlp(dims=[score_dim, 1], act=[tf.sigmoid], wd=wd)
        s_out = score_mlp(score_inp)[-1]
        s_out = tf.reshape(s_out, [-1, timespan])
        model['s_out'] = s_out

        # Loss function
        global_step = tf.Variable(0.0, trainable=False)
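        # Staircase exponential decay: the learning rate is multiplied by
        # learn_rate_decay every steps_per_learn_rate_decay steps.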
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        y_gt_shape = tf.shape(y_gt)
        num_ex = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

        # Pairwise IOU
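        # iou_soft is the [B, T, T] matrix of IOUs between every predicted
        # mask and every groundtruth mask, computed on soft (un-thresholded)
        # outputs.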
        iou_soft = base.f_iou(y_out, y_gt, timespan, pairwise=True)

        # Matching
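        # One-to-one assignment of predictions to groundtruth objects based
        # on the pairwise IOU matrix; `match` is a binary [B, T, T] tensor.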
        match = base.f_segm_match(iou_soft, s_gt)
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])

        # Weighted coverage (soft)
        wt_cov_soft = base.f_weighted_coverage(iou_soft, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = base.f_unweighted_coverage(iou_soft, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # IOU (soft)
        iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(tf.reduce_sum(iou_soft_mask, [1]) /
                                 match_count) / num_ex
        model['iou_soft'] = iou_soft
        gt_wt = coverage_weight(y_gt)
        wt_iou_soft = tf.reduce_sum(tf.reduce_sum(iou_soft_mask * gt_wt, [1]) /
                                    match_count) / num_ex
        model['wt_iou_soft'] = wt_iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_iou':
            segm_loss = -wt_iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = base.f_match_loss(
                y_out, y_gt, match, timespan, base.f_bce)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
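        # Confidence loss; the cumulative-minimum formulation presumably
        # encourages scores to be monotonically non-increasing over time.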
        conf_loss = base.f_conf_loss(s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        loss = loss_mix_ratio * conf_loss + segm_loss
        model['loss'] = loss

        tf.add_to_collection('losses', loss)
        total_loss = tf.add_n(tf.get_collection(
            'losses'), name='total_loss')
        model['total_loss'] = total_loss

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

        # Statistics
        # [B, M, N] * [B, M, N] => [B] * [B] => [1]
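        # Hard metrics threshold the soft masks at 0.5 before computing IOU.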
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = base.f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = base.f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = base.f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        # [B, T]
        iou_hard_mask = tf.reduce_sum(iou_hard * match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) /
                                 match_count) / num_ex
        model['iou_hard'] = iou_hard
        wt_iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask * gt_wt, [1]) /
                                    match_count) / num_ex
        model['wt_iou_hard'] = wt_iou_hard

        dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(dice * match, [1, 2]) /
                             match_count) / num_ex
        model['dice'] = dice

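        # Counting metrics: accuracy of the predicted object count, and the
        # (absolute) difference in count (DiC).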
        model['count_acc'] = base.f_count_acc(s_out, s_gt)
        model['dic'] = base.f_dic(s_out, s_gt, abs=False)
        model['dic_abs'] = base.f_dic(s_out, s_gt, abs=True)

    return model