def build_tracking_model(opt, device='/cpu:0'):
    """Build the TensorFlow graph of the recurrent tracking model.

    Given the T+1 sequence of input, return T sequence of output.

    For each of the ``T = rnn_seq_len`` steps the same CNN is applied to
    (1) the current frame, (2) the current frame multiplied by a box mask and
    (3) the next frame.  The three gradient-stopped feature maps are
    flattened, concatenated and fed to the recurrent cell; two single-layer
    MLPs map the resulting hidden features to a box and a score for the next
    frame.

    NOTE: this code uses the pre-1.0 TensorFlow API (``tf.pack`` and the
    axis-first argument order of ``tf.concat`` / ``tf.split``).

    Args:
        opt: Dictionary of options.  Keys read here: ``rnn_seq_len``,
            ``cnn_filter_size``, ``cnn_num_filter``, ``cnn_pool_size``,
            ``img_channel``, ``use_batch_norm``, ``img_height``,
            ``img_width``, ``weight_decay``, ``rnn_hidden_dim``,
            ``base_learn_rate``, ``learn_rate_decay_step``,
            ``learn_rate_decay_rate``, ``pretrain_model_filename`` and
            ``is_pretrain``.
        device: Device string handed to ``get_device_fn``.

    Returns:
        Dictionary of graph nodes with the keys ``imgs``, ``gt_bbox``,
        ``gt_score``, ``init_bbox``, ``init_rnn_state``, ``phase_train``,
        ``anneal_threshold`` (placeholders), ``cnn_global_feat_now``,
        ``cnn_roi_feat_now``, ``cnn_global_feat_next`` (features of the last
        unrolled step only, since they are overwritten on every step),
        ``final_rnn_state``, ``IOU_score``, ``predict_bbox``,
        ``predict_score``, ``IOU_loss``, ``CE_loss``, ``learn_rate`` and
        ``train_step``.
    """
    model = {}

    rnn_seq_len = opt['rnn_seq_len']
    cnn_filter_size = opt['cnn_filter_size']
    cnn_num_filter = opt['cnn_num_filter']
    cnn_pool_size = opt['cnn_pool_size']
    num_channel = opt['img_channel']
    use_bn = opt['use_batch_norm']
    height = opt['img_height']
    width = opt['img_width']
    weight_decay = opt['weight_decay']
    rnn_hidden_dim = opt['rnn_hidden_dim']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay_step = opt['learn_rate_decay_step']
    learn_rate_decay_rate = opt['learn_rate_decay_rate']
    pretrain_model_filename = opt['pretrain_model_filename']
    is_pretrain = opt['is_pretrain']

    with tf.device(get_device_fn(device)):
        phase_train = tf.placeholder('bool')

        # Threshold compared against a uniform sample to decide between the
        # predicted box and the groundtruth box as ROI (see the time loop).
        anneal_threshold = tf.placeholder(tf.float32, [1])

        # input image [B, T+1, H, W, C]
        imgs = tf.placeholder(
            tf.float32, [None, rnn_seq_len + 1, height, width, num_channel])
        img_shape = tf.shape(imgs)
        batch_size = img_shape[0]

        init_bbox = tf.placeholder(tf.float32, [None, 4])
        init_rnn_state = tf.placeholder(
            tf.float32, [None, rnn_hidden_dim * 2])
        gt_bbox = tf.placeholder(tf.float32, [None, rnn_seq_len + 1, 4])
        gt_score = tf.placeholder(tf.float32, [None, rnn_seq_len + 1])

        # Slot 0 belongs to the given initial frame; only slots 1..T are
        # filled with graph nodes and used below.
        IOU_score = [None] * (rnn_seq_len + 1)
        IOU_score[0] = 1

        model['imgs'] = imgs
        model['gt_bbox'] = gt_bbox
        model['gt_score'] = gt_score
        model['init_bbox'] = init_bbox
        model['init_rnn_state'] = init_rnn_state
        model['phase_train'] = phase_train
        model['anneal_threshold'] = anneal_threshold

        # define a CNN model
        cnn_filter = cnn_filter_size
        cnn_nlayer = len(cnn_filter)
        cnn_channel = [num_channel] + cnn_num_filter
        cnn_pool = cnn_pool_size
        cnn_act = [tf.nn.relu] * cnn_nlayer
        cnn_use_bn = [use_bn] * cnn_nlayer

        # load pretrained model
        # Without pretraining no initial weights are supplied.  The name must
        # be bound in both cases because it is passed to nn.cnn below.
        cnn_init_w = None
        if is_pretrain:
            # `[:]` copies each dataset into memory, so the file can be
            # closed as soon as the weights have been read.
            with h5py.File(pretrain_model_filename, 'r') as h5f:
                cnn_init_w = [
                    {
                        'w': h5f['cnn_w_{}'.format(ii)][:],
                        'b': h5f['cnn_b_{}'.format(ii)][:],
                    }
                    for ii in range(cnn_nlayer)
                ]
                # Every indexed beta/gamma entry is initialised from the
                # stored index-0 parameters.  The CNN is invoked three times
                # per step, hence 3 * rnn_seq_len entries (presumably one
                # batch-norm parameter set per invocation -- confirm in nn.cnn).
                for ii in range(cnn_nlayer):
                    for tt in range(3 * rnn_seq_len):
                        for w in ['beta', 'gamma']:
                            cnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[
                                'cnn_{}_0_{}'.format(ii, w)][:]

        cnn_model = nn.cnn(cnn_filter, cnn_channel, cnn_pool, cnn_act,
                           cnn_use_bn, phase_train=phase_train,
                           wd=weight_decay, init_weights=cnn_init_w)

        # define a RNN(LSTM) model
        cnn_subsample = np.array(cnn_pool).prod()
        rnn_h = int(height / cnn_subsample)
        rnn_w = int(width / cnn_subsample)
        rnn_dim = cnn_channel[-1]
        cnn_out_dim = rnn_h * rnn_w * rnn_dim   # flattened size of one map

        # input dimension of RNN: three feature maps are concatenated
        rnn_inp_dim = cnn_out_dim * 3

        rnn_state = [None] * (rnn_seq_len + 1)
        predict_bbox = [None] * (rnn_seq_len + 1)
        predict_score = [None] * (rnn_seq_len + 1)

        predict_bbox[0] = init_bbox
        predict_score[0] = 1

        # The extra last slot holds the initial state: at tt == 0 the loop
        # reads rnn_state[tt - 1], i.e. rnn_state[-1].
        rnn_state[-1] = init_rnn_state

        rnn_hidden_feat = [None] * rnn_seq_len

        rnn_cell = nn.lstm(rnn_inp_dim, rnn_hidden_dim, wd=weight_decay)

        # define two linear mapping MLPs:
        # RNN hidden state -> bbox
        # RNN hidden state -> score
        bbox_mlp_dims = [rnn_hidden_dim, 4]
        bbox_mlp_act = [None]
        bbox_mlp = nn.mlp(bbox_mlp_dims, bbox_mlp_act, add_bias=True,
                          phase_train=phase_train, wd=weight_decay)

        score_mlp_dims = [rnn_hidden_dim, 1]
        score_mlp_act = [tf.sigmoid]
        score_mlp = nn.mlp(score_mlp_dims, score_mlp_act, add_bias=True,
                           phase_train=phase_train, wd=weight_decay)

        # training through time
        for tt in range(rnn_seq_len):
            # extract global CNN feature map of the current frame
            h_cnn_global_now = cnn_model(imgs[:, tt, :, :, :])
            cnn_global_feat_now = h_cnn_global_now[-1]
            cnn_global_feat_now = tf.stop_gradient(
                cnn_global_feat_now)  # fix CNN during training
            model['cnn_global_feat_now'] = cnn_global_feat_now

            # extract ROI CNN feature map of the current frame
            # use_pred_bbox is 1.0 when the uniform sample falls below the
            # threshold and 0.0 otherwise.  The sample has shape [1], so one
            # draw per step is shared by the whole batch.
            use_pred_bbox = tf.to_float(
                tf.less(tf.random_uniform([1]), anneal_threshold))
            x1, y1, x2, y2 = tf.split(
                1, 4,
                use_pred_bbox * predict_bbox[tt] +
                (1 - use_pred_bbox) * gt_bbox[:, tt, :])
            idx_map = get_idx_map(tf.pack([batch_size, height, width]))
            # NOTE(review): presumably a mask that is non-zero inside the
            # box given by its (y, x) corners -- confirm in
            # get_filled_box_idx.
            mask_map = get_filled_box_idx(
                idx_map, tf.concat(1, [y1, x1]), tf.concat(1, [y2, x2]))

            # Apply the mask to every channel separately; packing yields
            # [C, B, H, W], which the transpose turns into [B, H, W, C].
            ROI_img = [imgs[:, tt, :, :, cc] * mask_map
                       for cc in range(num_channel)]

            h_cnn_roi_now = cnn_model(
                tf.transpose(tf.pack(ROI_img), [1, 2, 3, 0]))
            cnn_roi_feat_now = h_cnn_roi_now[-1]
            cnn_roi_feat_now = tf.stop_gradient(
                cnn_roi_feat_now)  # fix CNN during training
            model['cnn_roi_feat_now'] = cnn_roi_feat_now

            # extract global CNN feature map of the next frame
            h_cnn_global_next = cnn_model(imgs[:, tt + 1, :, :, :])
            cnn_global_feat_next = h_cnn_global_next[-1]
            cnn_global_feat_next = tf.stop_gradient(
                cnn_global_feat_next)  # fix CNN during training
            model['cnn_global_feat_next'] = cnn_global_feat_next

            # going through a RNN
            # RNN input = global CNN feat map + ROI CNN feat map
            #             + global CNN feat map of the next frame
            rnn_input = tf.concat(1, [
                tf.reshape(cnn_global_feat_now, [-1, cnn_out_dim]),
                tf.reshape(cnn_roi_feat_now, [-1, cnn_out_dim]),
                tf.reshape(cnn_global_feat_next, [-1, cnn_out_dim]),
            ])
            rnn_state[tt], _, _, _ = rnn_cell(rnn_input, rnn_state[tt - 1])
            # The state is 2 * rnn_hidden_dim wide; the second half is used
            # as the hidden feature (presumably the LSTM output h -- confirm
            # against nn.lstm).
            rnn_hidden_feat[tt] = tf.slice(
                rnn_state[tt], [0, rnn_hidden_dim], [-1, rnn_hidden_dim])

            # predict bbox and score
            raw_predict_bbox = bbox_mlp(rnn_hidden_feat[tt])[0]
            predict_bbox[tt + 1] = transform_box(
                raw_predict_bbox, height, width)
            predict_score[tt + 1] = score_mlp(rnn_hidden_feat[tt])[-1]

            # compute IOU
            IOU_score[tt + 1] = compute_IOU(
                predict_bbox[tt + 1], gt_bbox[:, tt + 1, :])

        model['final_rnn_state'] = rnn_state[rnn_seq_len - 1]

        # Stack the per-step outputs (time-major after tf.pack) and move the
        # batch dimension to the front.
        predict_bbox = tf.transpose(tf.pack(predict_bbox[1:]), [1, 0, 2])
        model['IOU_score'] = tf.transpose(tf.pack(IOU_score[1:]), [1, 0, 2])
        model['predict_bbox'] = predict_bbox
        model['predict_score'] = tf.transpose(tf.pack(predict_score[1:]))

        # compute IOU loss
        batch_size_f = tf.to_float(batch_size)
        rnn_seq_len_f = tf.to_float(rnn_seq_len)
        # Sum of the groundtruth scores of frames 1..T per sequence, clamped
        # to at least 1 so the division below is always defined.
        valid_seq_length = tf.reduce_sum(gt_score[:, 1:], [1])
        valid_seq_length = tf.maximum(1.0, valid_seq_length)
        # Negative IOU weighted by the groundtruth score, summed over time,
        # normalised per sequence and then averaged over the batch.
        IOU_loss = gt_score[:, 1:] * (- tf.concat(1, IOU_score[1:]))
        IOU_loss = tf.reduce_sum(IOU_loss, [1])
        IOU_loss /= valid_seq_length
        IOU_loss = tf.reduce_sum(IOU_loss) / batch_size_f

        # cross-entropy loss between groundtruth and predicted scores,
        # averaged over batch and time
        cross_entropy = -tf.reduce_sum(
            gt_score[:, 1:] * tf.log(tf.concat(1, predict_score[1:])) +
            (1 - gt_score[:, 1:]) *
            tf.log(1 - tf.concat(1, predict_score[1:]))
        ) / (batch_size_f * rnn_seq_len_f)

        model['IOU_loss'] = IOU_loss
        model['CE_loss'] = cross_entropy

        global_step = tf.Variable(0.0)
        eps = 1e-7

        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, learn_rate_decay_step,
            learn_rate_decay_rate, staircase=True)
        model['learn_rate'] = learn_rate

        # The optimised objective is the sum of both losses.
        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=1.0).minimize(IOU_loss + cross_entropy,
                               global_step=global_step)
        model['train_step'] = train_step

    return model
def get_model(opt, device='/cpu:0'):
    """Build the TensorFlow graph of a CNN + MLP binary classifier.

    The input image goes through a batch-normalised CNN, the last feature
    map is average-pooled, and an MLP ending in a sigmoid produces one value
    per example.  The model is trained with the loss returned by ``f_bce``
    (plus whatever else is in the ``'losses'`` collection) using Adam with an
    exponentially decayed learning rate.

    Args:
        opt: Dictionary of options.  Keys read here: ``inp_height``,
            ``inp_width``, ``inp_depth``, ``cnn_filter_size``, ``cnn_depth``,
            ``cnn_pool``, ``mlp_dims``, ``mlp_dropout``, ``weight_decay``,
            ``base_learn_rate``, ``learn_rate_decay`` and
            ``steps_per_learn_rate_decay``.
        device: Device string handed to ``get_device_fn``.

    Returns:
        Dictionary of graph nodes with the keys ``x``, ``y_gt``,
        ``phase_train``, ``y_out``, ``loss``, ``acc``, ``learn_rate`` and
        ``train_step``.  ``nn.cnn`` and ``nn.mlp`` also receive this
        dictionary through their ``model`` argument.
    """
    model = {}
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    cnn_filter_size = opt['cnn_filter_size']
    cnn_depth = opt['cnn_depth']
    cnn_pool = opt['cnn_pool']
    mlp_dims = opt['mlp_dims']
    mlp_dropout = opt['mlp_dropout']
    wd = opt['weight_decay']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']

    ############################
    # Input definition
    ############################
    with tf.device(get_device_fn(device)):
        x = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth], name='x')
        phase_train = tf.placeholder('bool', name='phase_train')
        y_gt = tf.placeholder('float', [None], name='y_gt')
        global_step = tf.Variable(0.0)

        ############################
        # Feature CNN definition
        ############################
        cnn_channels = [inp_depth] + cnn_depth
        cnn_nlayers = len(cnn_filter_size)
        cnn_use_bn = [True] * cnn_nlayers
        cnn_act = [tf.nn.relu] * cnn_nlayers
        cnn = nn.cnn(cnn_filter_size, cnn_channels, cnn_pool, cnn_act,
                     cnn_use_bn, model=model, phase_train=phase_train,
                     wd=wd, scope='cnn')

        subsample = np.array(cnn_pool).prod()
        # Height of the last feature map, used as the pooling window below.
        # It is forced to a plain int so that it stays an integer even when
        # `/` would be true division.
        cnn_h = int(inp_height // subsample)
        # After the average pooling only the channel dimension is kept as
        # feature dimension.
        # NOTE(review): the reshape to [-1, feat_dim] assumes the pooled map
        # is 1 x 1, which presumably needs a square feature map -- confirm
        # against nn.avg_pool.
        feat_dim = cnn_channels[-1]

        ############################
        # MLP definition
        ############################
        # Count the layers before the input dimension is prepended.
        mlp_nlayers = len(mlp_dims)
        mlp_dims = [feat_dim] + mlp_dims
        mlp_dropout_keep = [1 - mlp_dropout] * mlp_nlayers
        mlp_act = [tf.nn.relu] * (mlp_nlayers - 1) + [tf.sigmoid]
        mlp = nn.mlp(mlp_dims, mlp_act, model=model,
                     dropout_keep=mlp_dropout_keep, phase_train=phase_train)

        ############################
        # Computation graph
        ############################
        f = cnn(x)
        f = nn.avg_pool(f[-1], cnn_h)
        f = tf.reshape(f, [-1, feat_dim])
        y_out = mlp(f)[-1]
        y_out = tf.reshape(y_out, [-1])

        ############################
        # Loss function
        ############################
        num_ex = tf.shape(y_gt)[0]
        num_ex_f = tf.to_float(num_ex)
        bce = f_bce(y_out, y_gt)
        bce = tf.reduce_sum(bce) / num_ex_f
        # The total loss sums everything in the 'losses' collection, which
        # may hold terms added elsewhere (e.g. by nn.cnn / nn.mlp -- not
        # visible here).
        tf.add_to_collection('losses', bce)
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')

        ############################
        # Statistics
        ############################
        # Accuracy of the output thresholded at 0.5.
        y_out_thresh = tf.to_float(y_out > 0.5)
        acc = tf.reduce_sum(
            tf.to_float(tf.equal(y_out_thresh, y_gt))) / num_ex_f

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        eps = 1e-7
        train_step = tf.train.AdamOptimizer(learn_rate, epsilon=eps).minimize(
            total_loss, global_step=global_step)

        ############################
        # Computation nodes
        ############################
        model['x'] = x
        model['y_gt'] = y_gt
        model['phase_train'] = phase_train
        model['y_out'] = y_out
        model['loss'] = total_loss
        model['acc'] = acc
        model['learn_rate'] = learn_rate
        model['train_step'] = train_step

    return model
def get_model(opt, is_training=True): """The attention model""" log = logger.get() model = {} timespan = opt['timespan'] inp_height = opt['inp_height'] inp_width = opt['inp_width'] inp_depth = opt['inp_depth'] padding = opt['padding'] filter_height = opt['filter_height'] filter_width = opt['filter_width'] ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size'] ctrl_cnn_depth = opt['ctrl_cnn_depth'] ctrl_cnn_pool = opt['ctrl_cnn_pool'] ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim'] num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers'] ctrl_mlp_dim = opt['ctrl_mlp_dim'] attn_cnn_filter_size = opt['attn_cnn_filter_size'] attn_cnn_depth = opt['attn_cnn_depth'] attn_cnn_pool = opt['attn_cnn_pool'] attn_dcnn_filter_size = opt['attn_dcnn_filter_size'] attn_dcnn_depth = opt['attn_dcnn_depth'] attn_dcnn_pool = opt['attn_dcnn_pool'] mlp_dropout_ratio = opt['mlp_dropout'] attn_box_padding_ratio = opt['attn_box_padding_ratio'] wd = opt['weight_decay'] use_bn = opt['use_bn'] segm_loss_fn = opt['segm_loss_fn'] box_loss_fn = opt['box_loss_fn'] loss_mix_ratio = opt['loss_mix_ratio'] base_learn_rate = opt['base_learn_rate'] learn_rate_decay = opt['learn_rate_decay'] steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay'] use_knob = opt['use_knob'] knob_base = opt['knob_base'] knob_decay = opt['knob_decay'] steps_per_knob_decay = opt['steps_per_knob_decay'] knob_box_offset = opt['knob_box_offset'] knob_segm_offset = opt['knob_segm_offset'] knob_use_timescale = opt['knob_use_timescale'] gt_box_ctr_noise = opt['gt_box_ctr_noise'] gt_box_pad_noise = opt['gt_box_pad_noise'] gt_segm_noise = opt['gt_segm_noise'] squash_ctrl_params = opt['squash_ctrl_params'] fixed_order = opt['fixed_order'] clip_gradient = opt['clip_gradient'] fixed_gamma = opt['fixed_gamma'] num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter'] num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers'] pretrain_ctrl_net = opt['pretrain_ctrl_net'] pretrain_attn_net = opt['pretrain_attn_net'] pretrain_net = opt['pretrain_net'] if 
'freeze_ctrl_cnn' in opt: freeze_ctrl_cnn = opt['freeze_ctrl_cnn'] freeze_ctrl_rnn = opt['freeze_ctrl_rnn'] freeze_attn_net = opt['freeze_attn_net'] else: freeze_ctrl_cnn = True freeze_ctrl_rnn = True freeze_attn_net = True if 'freeze_ctrl_mlp' in opt: freeze_ctrl_mlp = opt['freeze_ctrl_mlp'] else: freeze_ctrl_mlp = freeze_ctrl_rnn if 'fixed_var' in opt: fixed_var = opt['fixed_var'] else: fixed_var = False if 'dynamic_var' in opt: dynamic_var = opt['dynamic_var'] else: dynamic_var = False if 'use_iou_box' in opt: use_iou_box = opt['use_iou_box'] else: use_iou_box = False if 'stop_canvas_grad' in opt: stop_canvas_grad = opt['stop_canvas_grad'] else: stop_canvas_grad = True if 'add_skip_conn' in opt: add_skip_conn = opt['add_skip_conn'] else: add_skip_conn = True if 'attn_cnn_skip' in opt: attn_cnn_skip = opt['attn_cnn_skip'] else: attn_cnn_skip = [add_skip_conn] * len(attn_cnn_filter_size) if 'disable_overwrite' in opt: disable_overwrite = opt['disable_overwrite'] else: disable_overwrite = True if 'add_d_out' in opt: add_d_out = opt['add_d_out'] add_y_out = opt['add_y_out'] else: add_d_out = False add_y_out = False if 'attn_add_d_out' in opt: attn_add_d_out = opt['attn_add_d_out'] attn_add_y_out = opt['attn_add_y_out'] attn_add_inp = opt['attn_add_inp'] attn_add_canvas = opt['attn_add_canvas'] else: attn_add_d_out = add_d_out attn_add_y_out = add_y_out attn_add_inp = True attn_add_canvas = True if 'ctrl_add_d_out' in opt: ctrl_add_d_out = opt['ctrl_add_d_out'] ctrl_add_y_out = opt['ctrl_add_y_out'] ctrl_add_inp = opt['ctrl_add_inp'] ctrl_add_canvas = opt['ctrl_add_canvas'] else: ctrl_add_d_out = add_d_out ctrl_add_y_out = add_y_out ctrl_add_inp = not ctrl_add_d_out ctrl_add_canvas = not ctrl_add_d_out if 'num_semantic_classes' in opt: num_semantic_classes = opt['num_semantic_classes'] else: num_semantic_classes = 1 rnd_hflip = opt['rnd_hflip'] rnd_vflip = opt['rnd_vflip'] rnd_transpose = opt['rnd_transpose'] rnd_colour = opt['rnd_colour'] 
############################ # Input definition ############################ # Input image, [B, H, W, D] x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth], name='x') x_shape = tf.shape(x) num_ex = x_shape[0] # Groundtruth segmentation, [B, T, H, W] y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width], name='y_gt') # Groundtruth confidence score, [B, T] s_gt = tf.placeholder('float', [None, timespan], name='s_gt') if add_d_out: d_in = tf.placeholder('float', [None, inp_height, inp_width, 8], name='d_in') model['d_in'] = d_in if add_y_out: y_in = tf.placeholder( 'float', [None, inp_height, inp_width, num_semantic_classes], name='y_in') model['y_in'] = y_in # Whether in training stage. phase_train = tf.placeholder('bool', name='phase_train') phase_train_f = tf.to_float(phase_train) model['x'] = x model['y_gt'] = y_gt model['s_gt'] = s_gt model['phase_train'] = phase_train # Global step if 'freeze_ctrl_cnn' in opt: global_step = tf.Variable(0.0, name='global_step') else: global_step = tf.Variable(0.0) ############################### # Random input transformation ############################### # Either add both or add nothing. 
assert (add_d_out and add_y_out) or (not add_d_out and not add_y_out) if not add_d_out: results = img.random_transformation(x, padding, phase_train, rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip, rnd_transpose=rnd_transpose, rnd_colour=rnd_colour, y=y_gt) x, y_gt = results['x'], results['y'] else: results = img.random_transformation(x, padding, phase_train, rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip, rnd_transpose=rnd_transpose, rnd_colour=rnd_colour, y=y_gt, d=d_in, c=y_in) x, y_gt, d_in, y_in = results['x'], results['y'], results[ 'd'], results['c'] model['d_in_trans'] = d_in model['y_in_trans'] = y_in model['x_trans'] = x model['y_gt_trans'] = y_gt ############################ # Canvas: external memory ############################ canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1])) ccnn_inp_depth = 0 acnn_inp_depth = 0 if ctrl_add_inp: ccnn_inp_depth += inp_depth if ctrl_add_canvas: ccnn_inp_depth += 1 if attn_add_inp: acnn_inp_depth += inp_depth if attn_add_canvas: acnn_inp_depth += 1 if ctrl_add_d_out: ccnn_inp_depth += 8 if ctrl_add_y_out: ccnn_inp_depth += num_semantic_classes if attn_add_d_out: acnn_inp_depth += 8 if attn_add_y_out: acnn_inp_depth += num_semantic_classes ############################# # Controller CNN definition ############################# ccnn_filters = ctrl_cnn_filter_size ccnn_nlayers = len(ccnn_filters) acnn_nlayers = len(attn_cnn_filter_size) ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth ccnn_pool = ctrl_cnn_pool ccnn_act = [tf.nn.relu] * ccnn_nlayers ccnn_use_bn = [use_bn] * ccnn_nlayers pt = pretrain_net or pretrain_ctrl_net if pt: log.info( 'Loading pretrained controller CNN weights from {}'.format(pt)) with h5py.File(pt, 'r') as h5f: ccnn_init_w = [{ 'w': h5f['ctrl_cnn_w_{}'.format(ii)][:], 'b': h5f['ctrl_cnn_b_{}'.format(ii)][:] } for ii in range(ccnn_nlayers)] for ii in range(ccnn_nlayers): for tt in range(timespan): for w in ['beta', 'gamma']: ccnn_init_w[ii]['{}_{}'.format( w, tt)] = h5f['ctrl_cnn_{}_{}_{}'.format(ii, 
tt, w)][:] ccnn_frozen = [freeze_ctrl_cnn] * ccnn_nlayers else: ccnn_init_w = None ccnn_frozen = [freeze_ctrl_cnn] * ccnn_nlayers ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act, ccnn_use_bn, phase_train=phase_train, wd=wd, scope='ctrl_cnn', model=model, init_weights=ccnn_init_w, frozen=ccnn_frozen) h_ccnn = [None] * timespan ############################ # Controller RNN definition ############################ ccnn_subsample = np.array(ccnn_pool).prod() crnn_h = inp_height / ccnn_subsample crnn_w = inp_width / ccnn_subsample crnn_dim = ctrl_rnn_hid_dim canvas_dim = inp_height * inp_width / (ccnn_subsample**2) glimpse_map_dim = crnn_h * crnn_w glimpse_feat_dim = ccnn_channels[-1] crnn_inp_dim = glimpse_feat_dim pt = pretrain_net or pretrain_ctrl_net if pt: log.info( 'Loading pretrained controller RNN weights from {}'.format(pt)) with h5py.File(pt, 'r') as h5f: crnn_init_w = {} for w in [ 'w_xi', 'w_hi', 'b_i', 'w_xf', 'w_hf', 'b_f', 'w_xu', 'w_hu', 'b_u', 'w_xo', 'w_ho', 'b_o' ]: key = 'ctrl_lstm_{}'.format(w) crnn_init_w[w] = h5f[key][:] crnn_frozen = freeze_ctrl_rnn else: crnn_init_w = None crnn_frozen = freeze_ctrl_rnn crnn_state = [None] * (timespan + 1) crnn_glimpse_map = [None] * timespan crnn_g_i = [None] * timespan crnn_g_f = [None] * timespan crnn_g_o = [None] * timespan h_crnn = [None] * timespan crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2])) crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm', init_weights=crnn_init_w, frozen=crnn_frozen, model=model) ############################ # Glimpse MLP definition ############################ gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim] gmlp_act = [tf.nn.relu] * \ (num_glimpse_mlp_layers - 1) + [tf.nn.softmax] gmlp_dropout = None pt = pretrain_net or pretrain_ctrl_net if pt: log.info('Loading pretrained glimpse MLP weights from {}'.format(pt)) with h5py.File(pt, 'r') as h5f: gmlp_init_w = [{ 'w': h5f['glimpse_mlp_w_{}'.format(ii)][:], 'b': 
h5f['glimpse_mlp_b_{}'.format(ii)][:] } for ii in range(num_glimpse_mlp_layers)] gmlp_frozen = [freeze_ctrl_rnn] * num_glimpse_mlp_layers else: gmlp_init_w = None gmlp_frozen = [freeze_ctrl_rnn] * num_glimpse_mlp_layers gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True, dropout_keep=gmlp_dropout, phase_train=phase_train, wd=wd, scope='glimpse_mlp', init_weights=gmlp_init_w, frozen=gmlp_frozen, model=model) ############################ # Controller MLP definition ############################ cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \ (num_ctrl_mlp_layers - 1) + [9] cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None] cmlp_dropout = None pt = pretrain_net or pretrain_ctrl_net if pt: log.info( 'Loading pretrained controller MLP weights from {}'.format(pt)) with h5py.File(pt, 'r') as h5f: cmlp_init_w = [{ 'w': h5f['ctrl_mlp_w_{}'.format(ii)][:], 'b': h5f['ctrl_mlp_b_{}'.format(ii)][:] } for ii in range(num_ctrl_mlp_layers)] cmlp_frozen = [freeze_ctrl_mlp] * num_ctrl_mlp_layers else: cmlp_init_w = None cmlp_frozen = [freeze_ctrl_mlp] * num_ctrl_mlp_layers cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True, dropout_keep=cmlp_dropout, phase_train=phase_train, wd=wd, scope='ctrl_mlp', init_weights=cmlp_init_w, frozen=cmlp_frozen, model=model) ########################### # Attention CNN definition ########################### acnn_filters = attn_cnn_filter_size acnn_nlayers = len(acnn_filters) acnn_channels = [acnn_inp_depth] + attn_cnn_depth acnn_pool = attn_cnn_pool acnn_act = [tf.nn.relu] * acnn_nlayers acnn_use_bn = [use_bn] * acnn_nlayers pt = pretrain_net or pretrain_attn_net if pt: log.info('Loading pretrained attention CNN weights from {}'.format(pt)) with h5py.File(pt, 'r') as h5f: acnn_init_w = [{ 'w': h5f['attn_cnn_w_{}'.format(ii)][:], 'b': h5f['attn_cnn_b_{}'.format(ii)][:] } for ii in range(acnn_nlayers)] for ii in range(acnn_nlayers): for tt in range(timespan): for w in ['beta', 'gamma']: key = 'attn_cnn_{}_{}_{}'.format(ii, tt, w) 
acnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[key][:] acnn_frozen = [freeze_attn_net] * acnn_nlayers else: acnn_init_w = None acnn_frozen = [freeze_attn_net] * acnn_nlayers acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act, acnn_use_bn, phase_train=phase_train, wd=wd, scope='attn_cnn', model=model, init_weights=acnn_init_w, frozen=acnn_frozen) x_patch = [None] * timespan h_acnn = [None] * timespan h_acnn_last = [None] * timespan acnn_subsample = np.array(acnn_pool).prod() acnn_h = filter_height / acnn_subsample acnn_w = filter_width / acnn_subsample core_depth = acnn_channels[-1] core_dim = acnn_h * acnn_w * core_depth ########################## # Score MLP definition ########################## pt = pretrain_net if pt: log.info('Loading score mlp weights from {}'.format(pt)) with h5py.File(pt, 'r') as h5f: smlp_init_w = [{ 'w': h5f['score_mlp_w_{}'.format(ii)][:], 'b': h5f['score_mlp_b_{}'.format(ii)][:] } for ii in range(1)] else: smlp_init_w = None smlp = nn.mlp([crnn_dim + core_dim, 1], [tf.sigmoid], wd=wd, scope='score_mlp', init_weights=smlp_init_w, model=model) s_out = [None] * timespan ############################# # Attention DCNN definition ############################# adcnn_filters = attn_dcnn_filter_size adcnn_nlayers = len(adcnn_filters) adcnn_unpool = attn_dcnn_pool adcnn_act = [tf.nn.relu] * adcnn_nlayers adcnn_channels = [core_depth] + attn_dcnn_depth adcnn_bn_nlayers = adcnn_nlayers adcnn_use_bn = [use_bn] * adcnn_bn_nlayers + \ [False] * (adcnn_nlayers - adcnn_bn_nlayers) if add_skip_conn: adcnn_skip_ch = [0] adcnn_channels_rev = acnn_channels[::-1][1:] + [acnn_inp_depth] adcnn_skip_rev = attn_cnn_skip[::-1] for sk, ch in zip(adcnn_skip_rev, adcnn_channels_rev): adcnn_skip_ch.append(ch if sk else 0) pass else: adcnn_skip_ch = None pt = pretrain_net or pretrain_attn_net if pt: log.info( 'Loading pretrained attention DCNN weights from {}'.format(pt)) with h5py.File(pt, 'r') as h5f: adcnn_init_w = [{ 'w': 
h5f['attn_dcnn_w_{}'.format(ii)][:], 'b': h5f['attn_dcnn_b_{}'.format(ii)][:] } for ii in range(adcnn_nlayers)] for ii in range(adcnn_bn_nlayers): for tt in range(timespan): for w in ['beta', 'gamma']: key = 'attn_dcnn_{}_{}_{}'.format(ii, tt, w) adcnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[key][:] adcnn_frozen = [freeze_attn_net] * adcnn_nlayers else: adcnn_init_w = None adcnn_frozen = [freeze_attn_net] * adcnn_nlayers adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool, adcnn_act, use_bn=adcnn_use_bn, skip_ch=adcnn_skip_ch, phase_train=phase_train, wd=wd, model=model, init_weights=adcnn_init_w, frozen=adcnn_frozen, scope='attn_dcnn') h_adcnn = [None] * timespan ########################## # Attention box ########################## attn_ctr_norm = [None] * timespan attn_lg_size = [None] * timespan attn_ctr = [None] * timespan attn_size = [None] * timespan attn_lg_var = [None] * timespan attn_lg_gamma = [None] * timespan attn_gamma = [None] * timespan attn_box_lg_gamma = [None] * timespan attn_top_left = [None] * timespan attn_bot_right = [None] * timespan attn_box = [None] * timespan iou_soft_box = [None] * timespan const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1])) attn_box_beta = tf.constant([-5.0]) attn_box_gamma = [None] * timespan ############################# # Groundtruth attention box ############################# # [B, T, 2] attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_lg_gamma_gt, \ attn_box_gt, \ attn_top_left_gt, attn_bot_right_gt = \ modellib.get_gt_attn(y_gt, filter_height, filter_width, padding_ratio=attn_box_padding_ratio, center_shift_ratio=0.0, min_padding=padding + 4) attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \ attn_lg_gamma_gt_noise, \ attn_box_gt_noise, \ attn_top_left_gt_noise, attn_bot_right_gt_noise = \ modellib.get_gt_attn(y_gt, filter_height, filter_width, padding_ratio=tf.random_uniform( tf.pack([num_ex, timespan, 1]), attn_box_padding_ratio - gt_box_pad_noise, attn_box_padding_ratio + 
gt_box_pad_noise), center_shift_ratio=tf.random_uniform( tf.pack([num_ex, timespan, 2]), -gt_box_ctr_noise, gt_box_ctr_noise), min_padding=padding + 4) attn_ctr_norm_gt = modellib.get_normalized_center(attn_ctr_gt, inp_height, inp_width) attn_lg_size_gt = modellib.get_normalized_size(attn_size_gt, inp_height, inp_width) ########################## # Groundtruth mix ########################## grd_match_cum = tf.zeros(tf.pack([num_ex, timespan])) # Scale mix ratio on different timesteps. if knob_use_timescale: gt_knob_time_scale = tf.reshape( 1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0), [1, timespan, 1]) else: gt_knob_time_scale = tf.ones([1, timespan, 1]) # Mix in groundtruth box. global_step_box = tf.maximum(0.0, global_step - knob_box_offset) gt_knob_prob_box = tf.train.exponential_decay(knob_base, global_step_box, steps_per_knob_decay, knob_decay, staircase=False) gt_knob_prob_box = tf.minimum(1.0, gt_knob_prob_box * gt_knob_time_scale) gt_knob_box = tf.to_float( tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box) model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0] # Mix in groundtruth segmentation. 
global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset) gt_knob_prob_segm = tf.train.exponential_decay(knob_base, global_step_segm, steps_per_knob_decay, knob_decay, staircase=False) gt_knob_prob_segm = tf.minimum(1.0, gt_knob_prob_segm * gt_knob_time_scale) gt_knob_segm = tf.to_float( tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm) model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0] ########################## # Segmentation output ########################## y_out_patch = [None] * timespan y_out = [None] * timespan y_out_lg_gamma = [None] * timespan y_out_beta = tf.constant([-5.0]) ########################## # Computation graph ########################## for tt in range(timespan): # Controller CNN ccnn_inp_list = [] acnn_inp_list = [] if ctrl_add_inp: ccnn_inp_list.append(x) if attn_add_inp: acnn_inp_list.append(x) if ctrl_add_canvas: ccnn_inp_list.append(canvas) if attn_add_canvas: acnn_inp_list.append(canvas) if ctrl_add_d_out: ccnn_inp_list.append(d_in) if attn_add_d_out: acnn_inp_list.append(d_in) if ctrl_add_y_out: ccnn_inp_list.append(y_in) if attn_add_y_out: acnn_inp_list.append(y_in) acnn_inp = tf.concat(3, acnn_inp_list) ccnn_inp = tf.concat(3, ccnn_inp_list) h_ccnn[tt] = ccnn(ccnn_inp) _h_ccnn = h_ccnn[tt] h_ccnn_last = _h_ccnn[-1] # Controller RNN [B, R1] crnn_inp = tf.reshape(h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim]) crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1) crnn_g_i[tt] = [None] * num_ctrl_rnn_iter crnn_g_f[tt] = [None] * num_ctrl_rnn_iter crnn_g_o[tt] = [None] * num_ctrl_rnn_iter h_crnn[tt] = [None] * num_ctrl_rnn_iter crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2])) crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter crnn_glimpse_map[tt][0] = tf.ones(tf.pack([num_ex, glimpse_map_dim, 1 ])) / glimpse_map_dim # Inner glimpse RNN for tt2 in range(num_ctrl_rnn_iter): crnn_glimpse = tf.reduce_sum(crnn_inp * crnn_glimpse_map[tt][tt2], [1]) crnn_state[tt][tt2], crnn_g_i[tt][tt2], 
crnn_g_f[tt][tt2], \ crnn_g_o[tt][tt2] = crnn_cell( crnn_glimpse, crnn_state[tt][tt2 - 1]) h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim]) h_gmlp = gmlp(h_crnn[tt][tt2]) if tt2 < num_ctrl_rnn_iter - 1: crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(h_gmlp[-1], 2) ctrl_out = cmlp(h_crnn[tt][-1])[-1] attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2]) attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2]) # Restrict to (-1, 1), (-inf, 0) if squash_ctrl_params: attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt]) attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt]) attn_ctr[tt], attn_size[tt] = modellib.get_unnormalized_attn( attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width) if fixed_var: attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2])) else: attn_lg_var[tt] = modellib.get_normalized_var( attn_size[tt], filter_height, filter_width) if dynamic_var: attn_lg_var[tt] = tf.slice(ctrl_out, [0, 4], [-1, 2]) if fixed_gamma: attn_lg_gamma[tt] = tf.constant([0.0]) y_out_lg_gamma[tt] = tf.constant([2.0]) else: attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1]) y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1]) attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1]) attn_gamma[tt] = tf.reshape(tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1]) attn_box_gamma[tt] = tf.reshape(tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1]) y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1]) attn_top_left[tt], attn_bot_right[tt] = modellib.get_box_coord( attn_ctr[tt], attn_size[tt]) # Initial filters (predicted) filter_y = modellib.get_gaussian_filter(attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 0], inp_height, filter_height) filter_x = modellib.get_gaussian_filter(attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1], inp_width, filter_width) filter_y_inv = tf.transpose(filter_y, [0, 2, 1]) filter_x_inv = tf.transpose(filter_x, [0, 2, 1]) # Attention box attn_box[tt] = modellib.extract_patch(const_ones * 
attn_box_gamma[tt], filter_y_inv, filter_x_inv, 1) attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta) attn_box[tt] = tf.reshape(attn_box[tt], [-1, 1, inp_height, inp_width]) # Kick in GT bbox. if use_knob: if fixed_order: attn_ctr_gtm = attn_ctr_gt_noise[:, tt, :] # attn_delta_gtm = attn_delta_gt_noise[:, tt, :] attn_size_gtm = attn_size_gt_noise[:, tt, :] else: if use_iou_box: iou_soft_box[tt] = modellib.f_iou_box( tf.expand_dims(attn_top_left[tt], 1), tf.expand_dims(attn_bot_right[tt], 1), attn_top_left_gt, attn_bot_right_gt) else: iou_soft_box[tt] = modellib.f_inter( attn_box[tt], attn_box_gt) / \ modellib.f_union(attn_box[tt], attn_box_gt, eps=1e-5) grd_match = modellib.f_greedy_match(iou_soft_box[tt], grd_match_cum) # [B, T, 1] grd_match = tf.expand_dims(grd_match, 2) attn_ctr_gtm = tf.reduce_sum(grd_match * attn_ctr_gt_noise, 1) attn_size_gtm = tf.reduce_sum(grd_match * attn_size_gt_noise, 1) attn_ctr[tt] = phase_train_f * gt_knob_box[:, tt, 0: 1] * \ attn_ctr_gtm + \ (1 - phase_train_f * gt_knob_box[:, tt, 0: 1]) * \ attn_ctr[tt] attn_size[tt] = phase_train_f * gt_knob_box[:, tt, 0: 1] * \ attn_size_gtm + \ (1 - phase_train_f * gt_knob_box[:, tt, 0: 1]) * \ attn_size[tt] attn_top_left[tt], attn_bot_right[tt] = modellib.get_box_coord( attn_ctr[tt], attn_size[tt]) filter_y = modellib.get_gaussian_filter(attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 0], inp_height, filter_height) filter_x = modellib.get_gaussian_filter(attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1], inp_width, filter_width) filter_y_inv = tf.transpose(filter_y, [0, 2, 1]) filter_x_inv = tf.transpose(filter_x, [0, 2, 1]) # Attended patch [B, A, A, D] x_patch[tt] = attn_gamma[tt] * modellib.extract_patch( acnn_inp, filter_y, filter_x, acnn_inp_depth) # CNN [B, A, A, D] => [B, RH2, RW2, RD2] h_acnn[tt] = acnn(x_patch[tt]) h_acnn_last[tt] = h_acnn[tt][-1] h_core = tf.reshape(h_acnn_last[tt], [-1, core_dim]) h_core_img = h_acnn_last[tt] # DCNN if add_skip_conn: 
h_acnn_rev = h_acnn[tt][::-1][1:] + [x_patch[tt]] adcnn_skip = [None] for sk, hcnn in zip(adcnn_skip_rev, h_acnn_rev): adcnn_skip.append(hcnn if sk else None) pass else: adcnn_skip = None h_adcnn[tt] = adcnn(h_core_img, skip=adcnn_skip) y_out_patch[tt] = tf.expand_dims(h_adcnn[tt][-1], 1) # Output y_out[tt] = modellib.extract_patch(h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1) y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta y_out[tt] = tf.sigmoid(y_out[tt]) y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width]) if disable_overwrite: y_out[tt] = tf.reshape(1 - canvas, [-1, 1, inp_height, inp_width]) * y_out[tt] # Scoring network smlp_inp = tf.concat(1, [h_crnn[tt][-1], h_core]) s_out[tt] = smlp(smlp_inp)[-1] # Here is the knob kick in GT segmentations at this timestep. # [B, N, 1, 1] if use_knob: _gt_knob_segm = tf.expand_dims( tf.expand_dims(gt_knob_segm[:, tt, 0:1], 2), 3) if fixed_order: _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3) else: grd_match = tf.expand_dims(grd_match, 3) _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1), 3) # Add independent uniform noise to groundtruth. 
_noise = tf.random_uniform( tf.pack([num_ex, inp_height, inp_width, 1]), 0, gt_segm_noise) _y_out = _y_out - _y_out * _noise _y_out = phase_train_f * _gt_knob_segm * _y_out + \ (1 - phase_train_f * _gt_knob_segm) * \ tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1]) else: _y_out = tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1]) y_out_last = _y_out canvas = tf.maximum(_y_out, canvas) if stop_canvas_grad: canvas = tf.stop_gradient(canvas) y_out_last = tf.stop_gradient(y_out_last) ######################### # Model outputs ######################### s_out = tf.concat(1, s_out) model['s_out'] = s_out y_out = tf.concat(1, y_out) model['y_out'] = y_out y_out_patch = tf.concat(1, y_out_patch) model['y_out_patch'] = y_out_patch attn_box = tf.concat(1, attn_box) model['attn_box'] = attn_box x_patch = tf.concat( 1, [tf.expand_dims(x_patch[tt], 1) for tt in range(timespan)]) model['x_patch'] = x_patch attn_top_left = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left]) attn_bot_right = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right]) attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr]) attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size]) attn_lg_gamma = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_gamma]) attn_box_lg_gamma = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_box_lg_gamma]) y_out_lg_gamma = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in y_out_lg_gamma]) model['attn_ctr'] = attn_ctr model['attn_size'] = attn_size model['attn_top_left'] = attn_top_left model['attn_bot_right'] = attn_bot_right model['attn_ctr_gt'] = attn_ctr_gt model['attn_size_gt'] = attn_size_gt model['attn_top_left_gt'] = attn_top_left_gt model['attn_bot_right_gt'] = attn_bot_right_gt model['attn_box_gt'] = attn_box_gt attn_ctr_norm = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr_norm]) attn_lg_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_size]) model['attn_ctr_norm'] = 
attn_ctr_norm model['attn_lg_size'] = attn_lg_size attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size]) attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt]) #################### # Glimpse #################### # T * T2 * [H', W'] => [T, T2, H', W'] crnn_glimpse_map = tf.concat(1, [ tf.expand_dims( tf.concat(1, [ tf.expand_dims(crnn_glimpse_map[tt][tt2], 1) for tt2 in range(num_ctrl_rnn_iter) ]), 1) for tt in range(timespan) ]) crnn_glimpse_map = tf.reshape( crnn_glimpse_map, [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w]) model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map model['global_step'] = global_step if not is_training: return model ######################### # Loss function ######################### num_ex_f = tf.to_float(x_shape[0]) max_num_obj = tf.to_float(timespan) ############################ # Box loss ############################ if fixed_order: # [B, T] for fixed order. iou_soft_box = modellib.f_iou(attn_box, attn_box_gt, pairwise=False) else: if use_knob: # [B, T, T] for matching. iou_soft_box = tf.concat(1, [ tf.expand_dims(iou_soft_box[tt], 1) for tt in range(timespan) ]) else: iou_soft_box = modellib.f_iou(attn_box, attn_box_gt, timespan, pairwise=True) # iou_soft_box = modellib.f_iou_pair_new(attn_box, attn_box_gt) identity_match = modellib.get_identity_match(num_ex, timespan, s_gt) if fixed_order: match_box = identity_match else: match_box = modellib.f_segm_match(iou_soft_box, s_gt) model['match_box'] = match_box match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2]) match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1]) match_count_box = tf.maximum(1.0, match_count_box) # [B] if fixed order, [B, T] if matching. 
if fixed_order: iou_soft_box_mask = iou_soft_box else: iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1]) iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1]) iou_soft_box = tf.reduce_sum(iou_soft_box / match_count_box) / num_ex_f if box_loss_fn == 'mse': box_loss = modellib.f_match_loss(attn_params, attn_params_gt, match_box, timespan, modellib.f_squared_err, model=model) elif box_loss_fn == 'huber': box_loss = modellib.f_match_loss(attn_params, attn_params_gt, match_box, timespan, modellib.f_huber) elif box_loss_fn == 'iou': box_loss = -iou_soft_box elif box_loss_fn == 'wt_cov': box_loss = -modellib.f_weighted_coverage(iou_soft_box, attn_box_gt) elif box_loss_fn == 'bce': box_loss_fn = modellib.f_match_loss(y_out, y_gt, match_box, timespan, f_bce) else: raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn)) model['box_loss'] = box_loss box_loss_coeff = tf.constant(1.0) model['box_loss_coeff'] = box_loss_coeff tf.add_to_collection('losses', box_loss_coeff * box_loss) ############################## # Segmentation loss ############################## # IoU (soft) iou_soft_pairwise = modellib.f_iou(y_out, y_gt, timespan, pairwise=True) real_match = modellib.f_segm_match(iou_soft_pairwise, s_gt) if fixed_order: iou_soft = modellib.f_iou(y_out, y_gt, pairwise=False) match = identity_match else: iou_soft = iou_soft_pairwise match = real_match model['match'] = match match_sum = tf.reduce_sum(match, reduction_indices=[2]) match_count = tf.reduce_sum(match_sum, reduction_indices=[1]) match_count = tf.maximum(1.0, match_count) # Weighted coverage (soft) wt_cov_soft = modellib.f_weighted_coverage(iou_soft_pairwise, y_gt) model['wt_cov_soft'] = wt_cov_soft unwt_cov_soft = modellib.f_unweighted_coverage(iou_soft_pairwise, match_count) model['unwt_cov_soft'] = unwt_cov_soft # [B] if fixed order, [B, T] if matching. 
if fixed_order: iou_soft_mask = iou_soft else: iou_soft_mask = tf.reduce_sum(iou_soft * match, [1]) iou_soft = tf.reduce_sum(iou_soft_mask, [1]) iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f model['iou_soft'] = iou_soft if segm_loss_fn == 'iou': segm_loss = -iou_soft elif segm_loss_fn == 'wt_cov': segm_loss = -wt_cov_soft elif segm_loss_fn == 'bce': segm_loss = f_match_bce(y_out, y_gt, match, timespan) else: raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn)) model['segm_loss'] = segm_loss segm_loss_coeff = tf.constant(1.0) tf.add_to_collection('losses', segm_loss_coeff * segm_loss) #################### # Score loss #################### conf_loss = modellib.f_conf_loss(s_out, match, timespan, use_cum_min=True) model['conf_loss'] = conf_loss tf.add_to_collection('losses', loss_mix_ratio * conf_loss) #################### # Total loss #################### total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss') model['loss'] = total_loss #################### # Optimizer #################### learn_rate = tf.train.exponential_decay(base_learn_rate, global_step, steps_per_learn_rate_decay, learn_rate_decay, staircase=True) model['learn_rate'] = learn_rate eps = 1e-7 optimizer = tf.train.AdamOptimizer(learn_rate, epsilon=eps) gvs = optimizer.compute_gradients(total_loss) capped_gvs = [] for grad, var in gvs: if grad is not None: capped_gvs.append((tf.clip_by_value(grad, -1, 1), var)) else: capped_gvs.append((grad, var)) train_step = optimizer.apply_gradients(capped_gvs, global_step=global_step) model['train_step'] = train_step #################### # Statistics #################### # Here statistics (hard measures) is always using matching. 
y_out_hard = tf.to_float(y_out > 0.5) iou_hard = modellib.f_iou(y_out_hard, y_gt, timespan, pairwise=True) wt_cov_hard = modellib.f_weighted_coverage(iou_hard, y_gt) model['wt_cov_hard'] = wt_cov_hard unwt_cov_hard = modellib.f_unweighted_coverage(iou_hard, match_count) model['unwt_cov_hard'] = unwt_cov_hard iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1]) iou_hard = tf.reduce_sum( tf.reduce_sum(iou_hard_mask, [1]) / match_count) / num_ex_f model['iou_hard'] = iou_hard dice = modellib.f_dice(y_out_hard, y_gt, timespan, pairwise=True) dice = tf.reduce_sum(tf.reduce_sum( dice * real_match, reduction_indices=[1, 2]) / match_count) / \ num_ex_f model['dice'] = dice model['count_acc'] = modellib.f_count_acc(s_out, s_gt) model['dic'] = modellib.f_dic(s_out, s_gt, abs=False) model['dic_abs'] = modellib.f_dic(s_out, s_gt, abs=True) ################################ # Controller output statistics ################################ if fixed_gamma: attn_lg_gamma_mean = tf.constant([0.0]) attn_box_lg_gamma_mean = tf.constant([2.0]) y_out_lg_gamma_mean = tf.constant([2.0]) else: attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / num_ex_f / timespan attn_box_lg_gamma_mean = tf.reduce_sum( attn_box_lg_gamma) / num_ex_f / timespan y_out_lg_gamma_mean = tf.reduce_sum( y_out_lg_gamma) / num_ex_f / timespan model['attn_lg_gamma_mean'] = attn_lg_gamma_mean model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean return model
def _read_h5_layer_weights(filename, prefix, num_layers):
    """Read per-layer `w`/`b` arrays from an HDF5 file and close it.

    Datasets are looked up as `<prefix>_w_<i>` and `<prefix>_b_<i>` for
    `i` in `range(num_layers)`. The `[:]` read copies each dataset into
    memory, so the returned arrays stay valid after the file is closed.
    """
    with h5py.File(filename, 'r') as h5f:
        return [{
            'w': h5f['{}_w_{}'.format(prefix, ii)][:],
            'b': h5f['{}_b_{}'.format(prefix, ii)][:]
        } for ii in range(num_layers)]


def get_model(opt):
    """Build the box model graph.

    Constructs placeholders, a controller CNN, a glimpse LSTM with its
    glimpse MLP, a controller MLP that emits the attention-box parameters,
    a score MLP, the box and confidence losses, and an Adam training step.

    Args:
        opt: Mapping of hyper-parameters. Keys read unconditionally:
            `timespan`, `inp_height`, `inp_width`, `inp_depth`, `padding`,
            `filter_height`, `filter_width`, `ctrl_cnn_filter_size`,
            `ctrl_cnn_depth`, `ctrl_cnn_pool`, `ctrl_rnn_hid_dim`,
            `num_ctrl_mlp_layers`, `ctrl_mlp_dim`, `attn_box_padding_ratio`,
            `weight_decay`, `use_bn`, `box_loss_fn`, `base_learn_rate`,
            `learn_rate_decay`, `steps_per_learn_rate_decay`,
            `pretrain_cnn`, `squash_ctrl_params`, `clip_gradient`,
            `fixed_order`, `num_ctrl_rnn_iter`, `num_glimpse_mlp_layers`,
            `rnd_hflip`, `rnd_vflip`, `rnd_transpose`, `rnd_colour`.
            Optional keys and their defaults: `pretrain_net` (None),
            `freeze_pretrain_cnn` (True), `fixed_var` (True),
            `use_iou_box` (False), `dynamic_var` (False),
            `num_semantic_classes` (1), `add_d_out` (False; when present,
            `add_y_out` must be present too).

    Returns:
        The `model` dict. This function stores the input placeholders,
        the transformed inputs, the attention-box tensors, `s_out`,
        `box_loss`, `conf_loss`, `loss`, `learn_rate`, `train_step` and
        `ctrl_rnn_glimpse_map` in it. The same dict is also passed to the
        `nn.*` and `modellib.*` builders as `model=`.

    Raises:
        Exception: If `box_loss_fn` is not one of the recognised names.
    """
    log = logger.get()
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    box_loss_fn = opt['box_loss_fn']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    pretrain_cnn = opt['pretrain_cnn']

    pretrain_net = opt.get('pretrain_net')
    freeze_pretrain_cnn = opt.get('freeze_pretrain_cnn', True)

    squash_ctrl_params = opt['squash_ctrl_params']
    # Read so the key stays required; the clipping range used by the
    # optimizer below is the constant [-1, 1].
    clip_gradient = opt['clip_gradient']
    fixed_order = opt['fixed_order']
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']

    fixed_var = opt.get('fixed_var', True)
    use_iou_box = opt.get('use_iou_box', False)
    dynamic_var = opt.get('dynamic_var', False)
    num_semantic_classes = opt.get('num_semantic_classes', 1)

    # `add_y_out` is only consulted (and then required) when `add_d_out`
    # is supplied.
    if 'add_d_out' in opt:
        add_d_out = opt['add_d_out']
        add_y_out = opt['add_y_out']
    else:
        add_d_out = False
        add_y_out = False

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    # Input image, [B, H, W, D]
    x = tf.placeholder(
        'float', [None, inp_height, inp_width, inp_depth], name='x')
    x_shape = tf.shape(x)
    num_ex = x_shape[0]

    # Groundtruth segmentation, [B, T, H, W]
    y_gt = tf.placeholder(
        'float', [None, timespan, inp_height, inp_width], name='y_gt')

    # Groundtruth confidence score, [B, T]
    s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

    if add_d_out:
        d_in = tf.placeholder(
            'float', [None, inp_height, inp_width, 8], name='d_in')
        model['d_in'] = d_in
    if add_y_out:
        y_in = tf.placeholder(
            'float', [None, inp_height, inp_width, num_semantic_classes],
            name='y_in')
        model['y_in'] = y_in

    # Whether in training stage.
    phase_train = tf.placeholder('bool', name='phase_train')
    # Not used below; kept so the sequence of created graph ops is
    # unchanged.
    phase_train_f = tf.to_float(phase_train)

    model['x'] = x
    model['y_gt'] = y_gt
    model['s_gt'] = s_gt
    model['phase_train'] = phase_train

    # Global step
    global_step = tf.Variable(0.0, name='global_step')

    ###############################
    # Random input transformation
    ###############################
    # Either add both or add nothing.
    assert (add_d_out and add_y_out) or (not add_d_out and not add_y_out)
    if not add_d_out:
        results = img.random_transformation(
            x,
            padding,
            phase_train,
            rnd_hflip=rnd_hflip,
            rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose,
            rnd_colour=rnd_colour,
            y=y_gt)
        x, y_gt = results['x'], results['y']
    else:
        results = img.random_transformation(
            x,
            padding,
            phase_train,
            rnd_hflip=rnd_hflip,
            rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose,
            rnd_colour=rnd_colour,
            y=y_gt,
            d=d_in,
            c=y_in)
        x, y_gt, d_in, y_in = (results['x'], results['y'], results['d'],
                               results['c'])
        model['d_in_trans'] = d_in
        model['y_in_trans'] = y_in
    model['x_trans'] = x
    model['y_gt_trans'] = y_gt

    ############################
    # Canvas: external memory
    ############################
    canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
    # Controller CNN input = image channels + 1 canvas channel, plus the
    # optional 8 `d_in` channels and `num_semantic_classes` `y_in` channels.
    ccnn_inp_depth = inp_depth + 1
    if add_d_out:
        ccnn_inp_depth += 8
    if add_y_out:
        ccnn_inp_depth += num_semantic_classes

    ############################
    # Controller CNN definition
    ############################
    ccnn_filters = ctrl_cnn_filter_size
    ccnn_nlayers = len(ccnn_filters)
    ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
    ccnn_pool = ctrl_cnn_pool
    ccnn_act = [tf.nn.relu] * ccnn_nlayers
    ccnn_use_bn = [use_bn] * ccnn_nlayers

    pt = pretrain_net or pretrain_cnn
    if pt:
        log.info('Loading pretrained weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            pt_cnn_nlayers = 0
            # Assuming pt_cnn_nlayers is smaller than or equal to
            # ccnn_nlayers.
            for ii in range(ccnn_nlayers):
                # Probe the three naming schemes in priority order; the
                # prefix of the last layer found is used for all reads.
                for prefix in ('attn_', '', 'ctrl_'):
                    if '{}cnn_w_{}'.format(prefix, ii) in h5f:
                        cnn_prefix = prefix
                        log.info('Loading {}cnn_w_{}'.format(prefix, ii))
                        log.info('Loading {}cnn_b_{}'.format(prefix, ii))
                        pt_cnn_nlayers += 1
                        break
            ccnn_init_w = [{
                'w': h5f['{}cnn_w_{}'.format(cnn_prefix, ii)][:],
                'b': h5f['{}cnn_b_{}'.format(cnn_prefix, ii)][:]
            } for ii in range(pt_cnn_nlayers)]
            # Per-timestep `beta` / `gamma` arrays stored alongside each
            # pretrained layer.
            for ii in range(pt_cnn_nlayers):
                for tt in range(timespan):
                    for w in ['beta', 'gamma']:
                        ccnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[
                            '{}cnn_{}_{}_{}'.format(cnn_prefix, ii, tt,
                                                    w)][:]
        ccnn_frozen = [freeze_pretrain_cnn] * pt_cnn_nlayers
        # Layers without pretrained weights are freshly initialised and
        # never frozen.
        for ii in range(pt_cnn_nlayers, ccnn_nlayers):
            ccnn_init_w.append(None)
            ccnn_frozen.append(False)
    else:
        ccnn_init_w = None
        ccnn_frozen = None

    ccnn = nn.cnn(ccnn_filters,
                  ccnn_channels,
                  ccnn_pool,
                  ccnn_act,
                  ccnn_use_bn,
                  phase_train=phase_train,
                  wd=wd,
                  scope='ctrl_cnn',
                  model=model,
                  init_weights=ccnn_init_w,
                  frozen=ccnn_frozen)
    h_ccnn = [None] * timespan

    ############################
    # Controller RNN definition
    ############################
    ccnn_subsample = np.array(ccnn_pool).prod()
    # Floor division keeps these integral regardless of the division
    # mode; they are used as shape dimensions below.
    crnn_h = inp_height // ccnn_subsample
    crnn_w = inp_width // ccnn_subsample
    crnn_dim = ctrl_rnn_hid_dim
    glimpse_map_dim = crnn_h * crnn_w
    glimpse_feat_dim = ccnn_channels[-1]
    crnn_inp_dim = glimpse_feat_dim

    pt = pretrain_net
    if pt:
        log.info(
            'Loading pretrained controller RNN weights from {}'.format(pt))
        with h5py.File(pt, 'r') as h5f:
            crnn_init_w = {}
            for w in [
                    'w_xi', 'w_hi', 'b_i', 'w_xf', 'w_hf', 'b_f', 'w_xu',
                    'w_hu', 'b_u', 'w_xo', 'w_ho', 'b_o'
            ]:
                key = 'ctrl_lstm_{}'.format(w)
                crnn_init_w[w] = h5f[key][:]
        crnn_frozen = None
    else:
        crnn_init_w = None
        crnn_frozen = None

    crnn_state = [None] * (timespan + 1)
    crnn_glimpse_map = [None] * timespan
    crnn_g_i = [None] * timespan
    crnn_g_f = [None] * timespan
    crnn_g_o = [None] * timespan
    h_crnn = [None] * timespan
    crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
    crnn_cell = nn.lstm(
        crnn_inp_dim,
        crnn_dim,
        wd=wd,
        scope='ctrl_lstm',
        init_weights=crnn_init_w,
        frozen=crnn_frozen,
        model=model)

    ############################
    # Glimpse MLP definition
    ############################
    gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
    gmlp_act = [tf.nn.relu] * \
        (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
    gmlp_dropout = None

    pt = pretrain_net
    if pt:
        log.info('Loading pretrained glimpse MLP weights from {}'.format(pt))
        gmlp_init_w = _read_h5_layer_weights(pt, 'glimpse_mlp',
                                             num_glimpse_mlp_layers)
        gmlp_frozen = None
    else:
        gmlp_init_w = None
        gmlp_frozen = None

    gmlp = nn.mlp(gmlp_dims,
                  gmlp_act,
                  add_bias=True,
                  dropout_keep=gmlp_dropout,
                  phase_train=phase_train,
                  wd=wd,
                  scope='glimpse_mlp',
                  init_weights=gmlp_init_w,
                  frozen=gmlp_frozen,
                  model=model)

    ############################
    # Controller MLP definition
    ############################
    # The final layer has 9 outputs; this function slices columns 0:2
    # (centre), 2:4 (log size), 4:6 (log variance, when `dynamic_var`)
    # and 7 (box log gamma).
    cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
        (num_ctrl_mlp_layers - 1) + [9]
    cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
    cmlp_dropout = None

    pt = pretrain_net
    if pt:
        log.info(
            'Loading pretrained controller MLP weights from {}'.format(pt))
        cmlp_init_w = _read_h5_layer_weights(pt, 'ctrl_mlp',
                                             num_ctrl_mlp_layers)
        cmlp_frozen = None
    else:
        cmlp_init_w = None
        cmlp_frozen = None

    cmlp = nn.mlp(cmlp_dims,
                  cmlp_act,
                  add_bias=True,
                  dropout_keep=cmlp_dropout,
                  phase_train=phase_train,
                  wd=wd,
                  scope='ctrl_mlp',
                  init_weights=cmlp_init_w,
                  frozen=cmlp_frozen,
                  model=model)

    ##########################
    # Score MLP definition
    ##########################
    pt = pretrain_net
    if pt:
        log.info('Loading score mlp weights from {}'.format(pt))
        smlp_init_w = _read_h5_layer_weights(pt, 'score_mlp', 1)
    else:
        smlp_init_w = None
    smlp = nn.mlp([crnn_dim, num_semantic_classes], [None],
                  wd=wd,
                  scope='score_mlp',
                  init_weights=smlp_init_w,
                  model=model)
    s_out = [None] * timespan

    ##########################
    # Attention box
    ##########################
    attn_ctr_norm = [None] * timespan
    attn_lg_size = [None] * timespan
    attn_lg_var = [None] * timespan
    attn_ctr = [None] * timespan
    attn_size = [None] * timespan
    attn_top_left = [None] * timespan
    attn_bot_right = [None] * timespan
    attn_box = [None] * timespan
    attn_box_lg_gamma = [None] * timespan
    attn_box_gamma = [None] * timespan
    const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
    attn_box_beta = tf.constant([-5.0])
    iou_soft_box = [None] * timespan

    #############################
    # Groundtruth attention box
    #############################
    attn_top_left_gt, attn_bot_right_gt, attn_box_gt = modellib.get_gt_box(
        y_gt, padding_ratio=attn_box_padding_ratio, center_shift_ratio=0.0)
    attn_ctr_gt, attn_size_gt = modellib.get_box_ctr_size(attn_top_left_gt,
                                                          attn_bot_right_gt)
    attn_ctr_norm_gt = modellib.get_normalized_center(attn_ctr_gt,
                                                      inp_height, inp_width)
    attn_lg_size_gt = modellib.get_normalized_size(attn_size_gt, inp_height,
                                                   inp_width)

    ##########################
    # Groundtruth mix
    ##########################
    grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

    ##########################
    # Computation graph
    ##########################
    for tt in range(timespan):
        # Controller CNN
        ccnn_inp_list = [x, canvas]
        if add_d_out:
            ccnn_inp_list.append(d_in)
        if add_y_out:
            ccnn_inp_list.append(y_in)
        ccnn_inp = tf.concat(3, ccnn_inp_list)
        h_ccnn[tt] = ccnn(ccnn_inp)
        h_ccnn_last = h_ccnn[tt][-1]

        # Controller RNN [B, R1]
        crnn_inp = tf.reshape(h_ccnn_last,
                              [-1, glimpse_map_dim, glimpse_feat_dim])
        crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
        crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
        crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
        crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
        h_crnn[tt] = [None] * num_ctrl_rnn_iter
        # Index -1 holds the initial state, so `tt2 - 1` at tt2 == 0
        # picks it up.
        crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
        # First glimpse weights every location equally.
        crnn_glimpse_map[tt][0] = tf.ones(
            tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

        # Inner glimpse RNN
        for tt2 in range(num_ctrl_rnn_iter):
            crnn_glimpse = tf.reduce_sum(
                crnn_inp * crnn_glimpse_map[tt][tt2], [1])
            crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \
                crnn_g_o[tt][tt2] = \
                crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
            # The second `crnn_dim` columns of the state are taken as the
            # RNN output.
            h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2], [0, crnn_dim],
                                       [-1, crnn_dim])
            h_gmlp = gmlp(h_crnn[tt][tt2])
            if tt2 < num_ctrl_rnn_iter - 1:
                crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(h_gmlp[-1],
                                                               2)

        ctrl_out = cmlp(h_crnn[tt][-1])[-1]
        attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
        attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

        # Restrict to (-1, 1), (-inf, 0)
        if squash_ctrl_params:
            attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
            attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])
        attn_ctr[tt], attn_size[tt] = modellib.get_unnormalized_attn(
            attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
        attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])

        if fixed_var:
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
        else:
            attn_lg_var[tt] = modellib.get_normalized_var(
                attn_size[tt], filter_height, filter_width)
        # `dynamic_var` overrides either choice above with the
        # controller's own output.
        if dynamic_var:
            attn_lg_var[tt] = tf.slice(ctrl_out, [0, 4], [-1, 2])

        attn_box_gamma[tt] = tf.reshape(
            tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
        attn_top_left[tt], attn_bot_right[tt] = modellib.get_box_coord(
            attn_ctr[tt], attn_size[tt])

        # Initial filters (predicted)
        filter_y = modellib.get_gaussian_filter(
            attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 0],
            inp_height, filter_height)
        filter_x = modellib.get_gaussian_filter(
            attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1],
            inp_width, filter_width)
        filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
        filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

        # Attention box
        attn_box[tt] = attn_box_gamma[tt] * modellib.extract_patch(
            const_ones, filter_y_inv, filter_x_inv, 1)
        attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
        attn_box[tt] = tf.reshape(attn_box[tt],
                                  [-1, 1, inp_height, inp_width])

        # Pick the groundtruth segmentation written to the canvas: the
        # tt-th one under fixed order, otherwise the greedily matched one.
        if fixed_order:
            _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
        else:
            if use_iou_box:
                iou_soft_box[tt] = modellib.f_iou_box(
                    tf.expand_dims(attn_top_left[tt], 1),
                    tf.expand_dims(attn_bot_right[tt], 1),
                    attn_top_left_gt, attn_bot_right_gt)
            else:
                iou_soft_box[tt] = modellib.f_inter(
                    attn_box[tt], attn_box_gt) / \
                    modellib.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
            grd_match = modellib.f_greedy_match(iou_soft_box[tt],
                                                grd_match_cum)
            grd_match = tf.expand_dims(tf.expand_dims(grd_match, 2), 3)
            _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1), 3)

        # Add independent uniform noise to groundtruth.
        _noise = tf.random_uniform(
            tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3)
        _y_out = _y_out - _y_out * _noise
        canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))

        # Scoring network
        s_out[tt] = smlp(h_crnn[tt][-1])[-1]
        if num_semantic_classes == 1:
            s_out[tt] = tf.sigmoid(s_out[tt])
        else:
            s_out[tt] = tf.nn.softmax(s_out[tt])

    #########################
    # Model outputs
    #########################
    s_out = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in s_out])
    if num_semantic_classes == 1:
        s_out = s_out[:, :, 0]
    model['s_out'] = s_out
    attn_box = tf.concat(1, attn_box)
    model['attn_box'] = attn_box
    attn_top_left = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left])
    attn_bot_right = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right])
    attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
    attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
    model['attn_top_left'] = attn_top_left
    model['attn_bot_right'] = attn_bot_right
    model['attn_ctr'] = attn_ctr
    model['attn_size'] = attn_size
    model['attn_ctr_norm_gt'] = attn_ctr_norm_gt
    model['attn_lg_size_gt'] = attn_lg_size_gt
    model['attn_top_left_gt'] = attn_top_left_gt
    model['attn_bot_right_gt'] = attn_bot_right_gt
    model['attn_box_gt'] = attn_box_gt
    attn_ctr_norm = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr_norm])
    attn_lg_size = tf.concat(
        1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_size])
    model['attn_ctr_norm'] = attn_ctr_norm
    model['attn_lg_size'] = attn_lg_size

    attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
    attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

    #########################
    # Loss function
    #########################
    y_gt_shape = tf.shape(y_gt)
    num_ex_f = tf.to_float(y_gt_shape[0])
    # Not used below; kept so the sequence of created graph ops is
    # unchanged.
    max_num_obj = tf.to_float(y_gt_shape[1])

    ############################
    # Box loss
    ############################
    if fixed_order:
        # [B, T] for fixed order.
        iou_soft_box = modellib.f_iou(attn_box, attn_box_gt, pairwise=False)
    else:
        # [B, T, T] for matching.
        iou_soft_box = tf.concat(
            1,
            [tf.expand_dims(iou_soft_box[tt], 1) for tt in range(timespan)])

    identity_match = modellib.get_identity_match(num_ex, timespan, s_gt)
    if fixed_order:
        match_box = identity_match
    else:
        match_box = modellib.f_segm_match(iou_soft_box, s_gt)

    model['match_box'] = match_box
    match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
    match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
    # Clamp to at least 1 so the division below never divides by zero.
    match_count_box = tf.maximum(1.0, match_count_box)

    # [B] if fixed order, [B, T] if matching.
    if fixed_order:
        iou_soft_box_mask = iou_soft_box
    else:
        iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
    iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
    iou_soft_box = tf.reduce_sum(iou_soft_box / match_count_box) / num_ex_f

    # One if/elif chain: exactly one branch is selected, and the trailing
    # `else` fires only for names that match none of them.
    # NOTE(review): `wt_iou_soft_box`, `box_map` and `box_map_gt` are not
    # assigned anywhere in this function; unless they exist at module
    # level, the 'wt_iou', 'wt_cov' and 'bce' branches raise NameError.
    if box_loss_fn == 'mse':
        box_loss = modellib.f_match_loss(
            attn_params,
            attn_params_gt,
            match_box,
            timespan,
            modellib.f_squared_err,
            model=model)
    elif box_loss_fn == 'huber':
        box_loss = modellib.f_match_loss(attn_params, attn_params_gt,
                                         match_box, timespan,
                                         modellib.f_huber)
    elif box_loss_fn == 'iou':
        box_loss = -iou_soft_box
    elif box_loss_fn == 'wt_iou':
        box_loss = -wt_iou_soft_box
    elif box_loss_fn == 'wt_cov':
        box_loss = -modellib.f_weighted_coverage(iou_soft_box, box_map_gt)
    elif box_loss_fn == 'bce':
        box_loss = modellib.f_match_loss(box_map, box_map_gt, match_box,
                                         timespan, modellib.f_bce)
    else:
        raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
    model['box_loss'] = box_loss

    box_loss_coeff = tf.constant(1.0)
    model['box_loss_coeff'] = box_loss_coeff
    tf.add_to_collection('losses', box_loss_coeff * box_loss)

    ####################
    # Score loss
    ####################
    if num_semantic_classes == 1:
        conf_loss = modellib.f_conf_loss(
            s_out, match_box, timespan, use_cum_min=True)
    else:
        conf_loss = modellib.f_conf_loss(
            1 - s_out[:, :, 0], match_box, timespan, use_cum_min=True)
    model['conf_loss'] = conf_loss
    conf_loss_coeff = tf.constant(1.0)
    tf.add_to_collection('losses', conf_loss_coeff * conf_loss)

    ####################
    # Total loss
    ####################
    total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    model['loss'] = total_loss

    ####################
    # Optimizer
    ####################
    learn_rate = tf.train.exponential_decay(
        base_learn_rate,
        global_step,
        steps_per_learn_rate_decay,
        learn_rate_decay,
        staircase=True)
    model['learn_rate'] = learn_rate
    eps = 1e-7

    optim = tf.train.AdamOptimizer(learn_rate, epsilon=eps)
    gvs = optim.compute_gradients(total_loss)
    # Clip every available gradient element-wise to [-1, 1]; variables
    # with no gradient are passed through untouched.
    capped_gvs = []
    for grad, var in gvs:
        if grad is not None:
            capped_gvs.append((tf.clip_by_value(grad, -1, 1), var))
        else:
            capped_gvs.append((grad, var))
    train_step = optim.apply_gradients(capped_gvs, global_step=global_step)
    model['train_step'] = train_step

    ####################
    # Glimpse
    ####################
    # T * T2 * [B, H' * W'] => [B, T, T2, H', W']
    crnn_glimpse_map = tf.concat(1, [
        tf.expand_dims(
            tf.concat(1, [
                tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
                for tt2 in range(num_ctrl_rnn_iter)
            ]), 1) for tt in range(timespan)
    ])
    crnn_glimpse_map = tf.reshape(
        crnn_glimpse_map,
        [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w])
    model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

    return model
def _read_layer_weights(h5f, prefix, num_layers):
    """Read the '<prefix>_w_<i>' / '<prefix>_b_<i>' arrays of each layer."""
    return [{'w': h5f['{}_w_{}'.format(prefix, ii)][:],
             'b': h5f['{}_b_{}'.format(prefix, ii)][:]}
            for ii in xrange(num_layers)]


def _read_bn_weights(h5f, init_w, prefix, num_layers, timespan):
    """Add '<prefix>_<layer>_<t>_{beta,gamma}' arrays to `init_w` in place."""
    for ii in xrange(num_layers):
        for tt in xrange(timespan):
            for w in ['beta', 'gamma']:
                key = '{}_{}_{}_{}'.format(prefix, ii, tt, w)
                init_w[ii]['{}_{}'.format(w, tt)] = h5f[key][:]


def _stack_timesteps(tensors):
    """Insert a new axis 1 into every tensor and concatenate along it."""
    return tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in tensors])


def get_model(opt, device='/cpu:0'):
    """Build the attention model graph.

    Creates the input placeholders, the controller CNN/LSTM, the glimpse and
    controller MLPs, the attention CNN/MLP/DCNN and the score MLP, unrolls
    them for `opt['timespan']` steps, and adds losses, the optimizer and
    evaluation statistics.

    Args:
        opt: Dictionary of hyper-parameters; every key read at the top of
            this function must be present.
        device: Device string handed to `base.get_device_fn`.

    Returns:
        Dictionary mapping names to graph tensors/ops, e.g. 'x', 'y_gt',
        's_gt', 'phase_train', 'y_out', 's_out', 'loss', 'train_step'.

    Raises:
        Exception: If `ctrl_rnn_inp_struct`, `box_loss_fn` or `segm_loss_fn`
            holds an unknown value.
    """
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']

    # NOTE: read but not used anywhere below.
    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    segm_loss_fn = opt['segm_loss_fn']
    box_loss_fn = opt['box_loss_fn']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    use_knob = opt['use_knob']
    knob_base = opt['knob_base']
    knob_decay = opt['knob_decay']
    steps_per_knob_decay = opt['steps_per_knob_decay']
    knob_box_offset = opt['knob_box_offset']
    knob_segm_offset = opt['knob_segm_offset']
    knob_use_timescale = opt['knob_use_timescale']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']

    squash_ctrl_params = opt['squash_ctrl_params']
    fixed_order = opt['fixed_order']
    clip_gradient = opt['clip_gradient']
    fixed_gamma = opt['fixed_gamma']

    ctrl_rnn_inp_struct = opt['ctrl_rnn_inp_struct']  # dense or attn
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']

    pretrain_ctrl_net = opt['pretrain_ctrl_net']
    pretrain_attn_net = opt['pretrain_attn_net']
    pretrain_net = opt['pretrain_net']
    freeze_ctrl_cnn = opt['freeze_ctrl_cnn']
    freeze_ctrl_rnn = opt['freeze_ctrl_rnn']
    freeze_attn_net = opt['freeze_attn_net']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    # Pretrained weight files: `pretrain_net` takes precedence over the
    # controller-only / attention-only files.
    ctrl_pt = pretrain_net or pretrain_ctrl_net
    attn_pt = pretrain_net or pretrain_attn_net

    ############################
    # Input definition
    ############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth], name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder(
            'float', [None, timespan, inp_height, inp_width], name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        phase_train_f = tf.to_float(phase_train)

        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0, name='global_step')

        ###############################
        # Random input transformation
        ###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        ############################
        # Canvas: external memory
        ############################
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        # The canvas is concatenated to the image as one extra channel.
        ccnn_inp_depth = inp_depth + 1
        acnn_inp_depth = inp_depth + 1

        ############################
        # Controller CNN definition
        ############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers
        ccnn_frozen = [freeze_ctrl_cnn] * ccnn_nlayers
        ccnn_init_w = None
        if ctrl_pt:
            log.info('Loading pretrained controller CNN weights from {}'.format(
                ctrl_pt))
            # Arrays are copied out with [:], so the file can be closed.
            with h5py.File(ctrl_pt, 'r') as h5f:
                ccnn_init_w = _read_layer_weights(
                    h5f, 'ctrl_cnn', ccnn_nlayers)
                _read_bn_weights(
                    h5f, ccnn_init_w, 'ctrl_cnn', ccnn_nlayers, timespan)

        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model,
                      init_weights=ccnn_init_w, frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

        ############################
        # Controller RNN definition
        ############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        # Floor division keeps the spatial dimensions integral.
        crnn_h = inp_height // ccnn_subsample
        crnn_w = inp_width // ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        if ctrl_rnn_inp_struct == 'dense':
            crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        elif ctrl_rnn_inp_struct == 'attn':
            crnn_inp_dim = glimpse_feat_dim
        else:
            # Fail early instead of hitting a NameError further down.
            raise Exception('Unknown ctrl_rnn_inp_struct: {}'.format(
                ctrl_rnn_inp_struct))

        crnn_frozen = freeze_ctrl_rnn
        crnn_init_w = None
        if ctrl_pt:
            log.info('Loading pretrained controller RNN weights from {}'.format(
                ctrl_pt))
            with h5py.File(ctrl_pt, 'r') as h5f:
                crnn_init_w = {}
                for w in ['w_xi', 'w_hi', 'b_i', 'w_xf', 'w_hf', 'b_f',
                          'w_xu', 'w_hu', 'b_u', 'w_xo', 'w_ho', 'b_o']:
                    key = 'ctrl_lstm_{}'.format(w)
                    crnn_init_w[w] = h5f[key][:]

        # One extra slot so that index -1 holds the initial state.
        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm',
                            init_weights=crnn_init_w, frozen=crnn_frozen,
                            model=model)

        ############################
        # Glimpse MLP definition
        ############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None
        gmlp_frozen = [freeze_ctrl_rnn] * num_glimpse_mlp_layers
        gmlp_init_w = None
        if ctrl_pt:
            log.info('Loading pretrained glimpse MLP weights from {}'.format(
                ctrl_pt))
            with h5py.File(ctrl_pt, 'r') as h5f:
                gmlp_init_w = _read_layer_weights(
                    h5f, 'glimpse_mlp', num_glimpse_mlp_layers)

        gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True,
                      dropout_keep=gmlp_dropout, phase_train=phase_train,
                      wd=wd, scope='glimpse_mlp', init_weights=gmlp_init_w,
                      frozen=gmlp_frozen, model=model)

        ############################
        # Controller MLP definition
        ############################
        # 9 outputs; columns 0-1, 2-3, 6, 7 and 8 are sliced in the loop.
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        cmlp_frozen = [freeze_ctrl_rnn] * num_ctrl_mlp_layers
        cmlp_init_w = None
        if ctrl_pt:
            log.info('Loading pretrained controller MLP weights from {}'.format(
                ctrl_pt))
            with h5py.File(ctrl_pt, 'r') as h5f:
                cmlp_init_w = _read_layer_weights(
                    h5f, 'ctrl_mlp', num_ctrl_mlp_layers)

        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout, phase_train=phase_train,
                      wd=wd, scope='ctrl_mlp', init_weights=cmlp_init_w,
                      frozen=cmlp_frozen, model=model)

        ###########################
        # Attention CNN definition
        ###########################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = attn_cnn_pool
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers
        acnn_frozen = [freeze_attn_net] * acnn_nlayers
        acnn_init_w = None
        if attn_pt:
            log.info('Loading pretrained attention CNN weights from {}'.format(
                attn_pt))
            with h5py.File(attn_pt, 'r') as h5f:
                acnn_init_w = _read_layer_weights(
                    h5f, 'attn_cnn', acnn_nlayers)
                _read_bn_weights(
                    h5f, acnn_init_w, 'attn_cnn', acnn_nlayers, timespan)

        acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act,
                      acnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='attn_cnn', model=model,
                      init_weights=acnn_init_w, frozen=acnn_frozen)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

        ############################
        # Attention MLP definition
        ############################
        acnn_subsample = np.array(acnn_pool).prod()
        acnn_h = filter_height // acnn_subsample
        acnn_w = filter_width // acnn_subsample
        amlp_inp_dim = acnn_h * acnn_w * acnn_channels[-1]
        core_depth = attn_mlp_depth
        core_dim = acnn_h * acnn_w * core_depth
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None
        amlp_frozen = [freeze_attn_net] * num_attn_mlp_layers
        amlp_init_w = None
        if attn_pt:
            log.info('Loading pretrained attention MLP weights from {}'.format(
                attn_pt))
            with h5py.File(attn_pt, 'r') as h5f:
                amlp_init_w = _read_layer_weights(
                    h5f, 'attn_mlp', num_attn_mlp_layers)

        amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout,
                      phase_train=phase_train, wd=wd, scope='attn_mlp',
                      init_weights=amlp_init_w, frozen=amlp_frozen,
                      model=model)

        ##########################
        # Score MLP definition
        ##########################
        # Score weights only come from the full pretrained network.
        smlp_init_w = None
        if pretrain_net:
            log.info('Loading score mlp weights from {}'.format(pretrain_net))
            with h5py.File(pretrain_net, 'r') as h5f:
                smlp_init_w = _read_layer_weights(h5f, 'score_mlp', 1)

        smlp = nn.mlp([crnn_dim + core_dim, 1], [tf.sigmoid], wd=wd,
                      scope='score_mlp', init_weights=smlp_init_w,
                      model=model)
        s_out = [None] * timespan

        #############################
        # Attention DCNN definition
        #############################
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = attn_dcnn_pool
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth
        adcnn_bn_nlayers = adcnn_nlayers
        adcnn_use_bn = [use_bn] * adcnn_bn_nlayers + \
            [False] * (adcnn_nlayers - adcnn_bn_nlayers)
        adcnn_skip_ch = [0] + acnn_channels[::-1][1:]
        adcnn_frozen = [freeze_attn_net] * adcnn_nlayers
        adcnn_init_w = None
        if attn_pt:
            log.info('Loading pretrained attention DCNN weights from {}'.format(
                attn_pt))
            with h5py.File(attn_pt, 'r') as h5f:
                adcnn_init_w = _read_layer_weights(
                    h5f, 'attn_dcnn', adcnn_nlayers)
                _read_bn_weights(
                    h5f, adcnn_init_w, 'attn_dcnn', adcnn_bn_nlayers,
                    timespan)

        adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool,
                        adcnn_act, use_bn=adcnn_use_bn,
                        skip_ch=adcnn_skip_ch, phase_train=phase_train,
                        wd=wd, model=model, init_weights=adcnn_init_w,
                        frozen=adcnn_frozen, scope='attn_dcnn')
        h_adcnn = [None] * timespan

        ##########################
        # Attention box
        ##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_lg_gamma = [None] * timespan
        attn_gamma = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        iou_soft_box = [None] * timespan
        const_ones = tf.ones(
            tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        attn_box_gamma = [None] * timespan

        #############################
        # Groundtruth attention box
        #############################
        # [B, T, 2]
        attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \
            attn_top_left_gt, attn_bot_right_gt = \
            base.get_gt_attn(y_gt,
                             padding_ratio=attn_box_padding_ratio,
                             center_shift_ratio=0.0,
                             min_padding=padding + 4)
        # Same groundtruth boxes with random padding and centre jitter.
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise),
                             min_padding=padding + 4)
        attn_ctr_norm_gt = base.get_normalized_center(
            attn_ctr_gt, inp_height, inp_width)
        attn_lg_size_gt = base.get_normalized_size(
            attn_size_gt, inp_height, inp_width)

        ##########################
        # Groundtruth mix
        ##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

        # Scale mix ratio on different timesteps.
        if knob_use_timescale:
            gt_knob_time_scale = tf.reshape(
                1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0),
                [1, timespan, 1])
        else:
            gt_knob_time_scale = tf.ones([1, timespan, 1])

        # Mix in groundtruth box.
        global_step_box = tf.maximum(0.0, global_step - knob_box_offset)
        gt_knob_prob_box = tf.train.exponential_decay(
            knob_base, global_step_box, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_box = tf.minimum(
            1.0, gt_knob_prob_box * gt_knob_time_scale)
        gt_knob_box = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box)
        model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0]

        # Mix in groundtruth segmentation.
        global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset)
        gt_knob_prob_segm = tf.train.exponential_decay(
            knob_base, global_step_segm, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_segm = tf.minimum(
            1.0, gt_knob_prob_segm * gt_knob_time_scale)
        gt_knob_segm = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm)
        model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0]

        ##########################
        # Segmentation output
        ##########################
        y_out = [None] * timespan
        y_out_lg_gamma = [None] * timespan
        y_out_beta = tf.constant([-5.0])

        ##########################
        # Computation graph
        ##########################
        for tt in xrange(timespan):
            # Controller CNN
            ccnn_inp = tf.concat(3, [x, canvas])
            acnn_inp = ccnn_inp
            h_ccnn[tt] = ccnn(ccnn_inp)
            _h_ccnn = h_ccnn[tt]
            h_ccnn_last = _h_ccnn[-1]

            # Controller RNN [B, R1]
            if ctrl_rnn_inp_struct == 'dense':
                crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
                crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[tt] = \
                    crnn_cell(crnn_inp, crnn_state[tt - 1])
                h_crnn[tt] = tf.slice(
                    crnn_state[tt], [0, crnn_dim], [-1, crnn_dim])
                ctrl_out = cmlp(h_crnn[tt])[-1]
            elif ctrl_rnn_inp_struct == 'attn':
                crnn_inp = tf.reshape(
                    h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])
                crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
                crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
                h_crnn[tt] = [None] * num_ctrl_rnn_iter
                # The inner glimpse RNN starts from a zero state each step.
                crnn_state[tt][-1] = tf.zeros(
                    tf.pack([num_ex, crnn_dim * 2]))
                crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
                # First glimpse weights every location equally.
                crnn_glimpse_map[tt][0] = tf.ones(
                    tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

                # Inner glimpse RNN
                for tt2 in xrange(num_ctrl_rnn_iter):
                    crnn_glimpse = tf.reduce_sum(
                        crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                    crnn_state[tt][tt2], crnn_g_i[tt][tt2], \
                        crnn_g_f[tt][tt2], crnn_g_o[tt][tt2] = crnn_cell(
                            crnn_glimpse, crnn_state[tt][tt2 - 1])
                    h_crnn[tt][tt2] = tf.slice(
                        crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim])
                    h_gmlp = gmlp(h_crnn[tt][tt2])
                    if tt2 < num_ctrl_rnn_iter - 1:
                        crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(
                            h_gmlp[-1], 2)
                ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))

            if fixed_gamma:
                attn_lg_gamma[tt] = tf.constant([0.0])
                y_out_lg_gamma[tt] = tf.constant([2.0])
            else:
                attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1])
                y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1])
            # The box gamma is always predicted, even with fixed_gamma.
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])

            attn_gamma[tt] = tf.reshape(
                tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1])
            attn_box_gamma[tt] = tf.reshape(
                tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
            y_out_lg_gamma[tt] = tf.reshape(
                y_out_lg_gamma[tt], [-1, 1, 1, 1])

            # Initial filters (predicted)
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            attn_box[tt] = base.extract_patch(
                const_ones * attn_box_gamma[tt], filter_y_inv,
                filter_x_inv, 1)
            attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
            attn_box[tt] = tf.reshape(
                attn_box[tt], [-1, 1, inp_height, inp_width])

            # Kick in GT bbox.
            if use_knob:
                if fixed_order:
                    attn_ctr_gtm = attn_ctr_gt_noise[:, tt, :]
                    attn_size_gtm = attn_size_gt_noise[:, tt, :]
                else:
                    iou_soft_box[tt] = base.f_inter(
                        attn_box[tt], attn_box_gt) / \
                        base.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
                    grd_match = base.f_greedy_match(
                        iou_soft_box[tt], grd_match_cum)

                    # [B, T, 1]
                    grd_match = tf.expand_dims(grd_match, 2)
                    attn_ctr_gtm = tf.reduce_sum(
                        grd_match * attn_ctr_gt_noise, 1)
                    attn_size_gtm = tf.reduce_sum(
                        grd_match * attn_size_gt_noise, 1)

                # Groundtruth is only mixed in while training.
                knob_box_tt = phase_train_f * gt_knob_box[:, tt, 0: 1]
                attn_ctr[tt] = knob_box_tt * attn_ctr_gtm + \
                    (1 - knob_box_tt) * attn_ctr[tt]
                attn_size[tt] = knob_box_tt * attn_size_gtm + \
                    (1 - knob_box_tt) * attn_size[tt]

            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            # Recompute the filters from the (possibly mixed) box.
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])
            # NOTE(review): the results of these two calls are discarded,
            # so they do not block any gradient. Kept as-is because
            # assigning them would change training behavior; confirm intent.
            tf.stop_gradient(filter_y)
            tf.stop_gradient(filter_x)

            # Attended patch [B, A, A, D]
            x_patch[tt] = attn_gamma[tt] * base.extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]

            # Dense segmentation network [B, R] => [B, M]
            amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core_img = tf.reshape(
                h_core, [-1, acnn_h, acnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core_img, skip=skip)

            # Output
            y_out[tt] = base.extract_patch(
                h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(
                y_out[tt], [-1, 1, inp_height, inp_width])

            # Scoring network
            if ctrl_rnn_inp_struct == 'dense':
                smlp_inp = tf.concat(1, [h_crnn[tt], h_core])
            elif ctrl_rnn_inp_struct == 'attn':
                smlp_inp = tf.concat(1, [h_crnn[tt][-1], h_core])
            s_out[tt] = smlp(smlp_inp)[-1]

            # Here is the knob kick in GT segmentations at this timestep.
            # [B, N, 1, 1]
            if use_knob:
                _gt_knob_segm = tf.expand_dims(
                    tf.expand_dims(gt_knob_segm[:, tt, 0: 1], 2), 3)

                if fixed_order:
                    _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
                else:
                    grd_match = tf.expand_dims(grd_match, 3)
                    _y_out = tf.expand_dims(tf.reduce_sum(
                        grd_match * y_gt, 1), 3)

                # Add independent uniform noise to groundtruth.
                _noise = tf.random_uniform(
                    tf.pack([num_ex, inp_height, inp_width, 1]), 0,
                    gt_segm_noise)
                _y_out = _y_out - _y_out * _noise
                _y_out = phase_train_f * _gt_knob_segm * _y_out + \
                    (1 - phase_train_f * _gt_knob_segm) * \
                    tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1])
            else:
                _y_out = tf.reshape(
                    y_out[tt], [-1, inp_height, inp_width, 1])

            # No gradient flows back through the canvas.
            canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))

        #########################
        # Model outputs
        #########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        x_patch = _stack_timesteps(x_patch)
        model['x_patch'] = x_patch

        attn_top_left = _stack_timesteps(attn_top_left)
        attn_bot_right = _stack_timesteps(attn_bot_right)
        attn_ctr = _stack_timesteps(attn_ctr)
        attn_size = _stack_timesteps(attn_size)
        attn_lg_gamma = _stack_timesteps(attn_lg_gamma)
        attn_box_lg_gamma = _stack_timesteps(attn_box_lg_gamma)
        y_out_lg_gamma = _stack_timesteps(y_out_lg_gamma)
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_box_gt'] = attn_box_gt
        attn_ctr_norm = _stack_timesteps(attn_ctr_norm)
        attn_lg_size = _stack_timesteps(attn_lg_size)
        model['attn_ctr_norm'] = attn_ctr_norm
        model['attn_lg_size'] = attn_lg_size

        attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
        attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

        #########################
        # Loss function
        #########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

        ############################
        # Box loss
        ############################
        if fixed_order:
            # [B, T] for fixed order.
            iou_soft_box = base.f_iou(attn_box, attn_box_gt, pairwise=False)
        elif use_knob:
            # [B, T, T] for matching; reuse the per-step IoU from the loop.
            iou_soft_box = _stack_timesteps(iou_soft_box)
        else:
            iou_soft_box = base.f_iou(
                attn_box, attn_box_gt, timespan, pairwise=True)

        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        if fixed_order:
            match_box = identity_match
        else:
            match_box = base.f_segm_match(iou_soft_box, s_gt)

        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(
            match_sum_box, reduction_indices=[1])
        # Avoid dividing by zero for examples without any match.
        match_count_box = tf.maximum(1.0, match_count_box)

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_box_mask = iou_soft_box
        else:
            iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
        iou_soft_box = tf.reduce_sum(
            iou_soft_box / match_count_box) / num_ex_f

        if box_loss_fn == 'mse':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan,
                base.f_squared_err, model=model)
        elif box_loss_fn == 'huber':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan,
                base.f_huber)
        elif box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -base.f_weighted_coverage(iou_soft_box, attn_box_gt)
        elif box_loss_fn == 'bce':
            # The result used to be assigned to `box_loss_fn`, leaving
            # `box_loss` undefined.
            # NOTE(review): `f_bce` is unqualified while sibling branches
            # use `base.f_*`; confirm it is in scope at module level.
            box_loss = base.f_match_loss(
                y_out, y_gt, match_box, timespan, f_bce)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        model['box_loss_coeff'] = box_loss_coeff
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

        ##############################
        # Segmentation loss
        ##############################
        # IoU (soft)
        iou_soft_pairwise = base.f_iou(y_out, y_gt, timespan, pairwise=True)
        real_match = base.f_segm_match(iou_soft_pairwise, s_gt)
        if fixed_order:
            iou_soft = base.f_iou(y_out, y_gt, pairwise=False)
            match = identity_match
        else:
            iou_soft = iou_soft_pairwise
            match = real_match
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = base.f_weighted_coverage(iou_soft_pairwise, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = base.f_unweighted_coverage(
            iou_soft_pairwise, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_mask = iou_soft
        else:
            iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(iou_soft_mask, [1])
        iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f
        model['iou_soft'] = iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            # NOTE(review): `f_match_bce` is unqualified; confirm it is
            # defined or imported at module level.
            segm_loss = f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

        ####################
        # Score loss
        ####################
        conf_loss = base.f_conf_loss(
            s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        tf.add_to_collection('losses', loss_mix_ratio * conf_loss)

        ####################
        # Total loss
        ####################
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7
        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(
                total_loss, global_step=global_step)
        model['train_step'] = train_step

        ####################
        # Statistics
        ####################
        # Here statistics (hard measures) is always using matching.
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = base.f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = base.f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = base.f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1])
        iou_hard = tf.reduce_sum(
            tf.reduce_sum(iou_hard_mask, [1]) / match_count) / num_ex_f
        model['iou_hard'] = iou_hard

        dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(
            dice * real_match, reduction_indices=[1, 2]) / match_count) / \
            num_ex_f
        model['dice'] = dice

        model['count_acc'] = base.f_count_acc(s_out, s_gt)
        model['dic'] = base.f_dic(s_out, s_gt, abs=False)
        model['dic_abs'] = base.f_dic(s_out, s_gt, abs=True)

        ################################
        # Controller output statistics
        ################################
        if fixed_gamma:
            attn_lg_gamma_mean = tf.constant([0.0])
            attn_box_lg_gamma_mean = tf.constant([2.0])
            y_out_lg_gamma_mean = tf.constant([2.0])
        else:
            attn_lg_gamma_mean = tf.reduce_sum(
                attn_lg_gamma) / num_ex_f / timespan
            attn_box_lg_gamma_mean = tf.reduce_sum(
                attn_box_lg_gamma) / num_ex_f / timespan
            y_out_lg_gamma_mean = tf.reduce_sum(
                y_out_lg_gamma) / num_ex_f / timespan
        model['attn_lg_gamma_mean'] = attn_lg_gamma_mean
        model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean
        model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean

        ####################
        # Debug gradients
        ####################
        ctrl_mlp_b_grad = tf.gradients(total_loss, model['ctrl_mlp_b_0'])
        model['ctrl_mlp_b_grad'] = ctrl_mlp_b_grad[0]

        ####################
        # Glimpse
        ####################
        # T * T2 * [H', W'] => [T, T2, H', W']
        # Glimpse maps are only populated for the 'attn' input structure.
        if ctrl_rnn_inp_struct == 'attn':
            crnn_glimpse_map = _stack_timesteps(
                [_stack_timesteps(crnn_glimpse_map[tt])
                 for tt in xrange(timespan)])
            crnn_glimpse_map = tf.reshape(
                crnn_glimpse_map,
                [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w])
            model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

    return model
def build_inference_network(self, x):
    """Assemble the inference graph: normalized CNN followed by an MLP.

    The convolutional activation functions are wrapped with the
    normalization selected by `config.norm_field`; the CNN output is
    reshaped to `config.mlp_dims[0]` columns and fed to the MLP.

    Args:
        x: Input handed to `nn.cnn`.

    Returns:
        The output of `nn.mlp` (the logits).

    Raises:
        Exception: If `config.norm_field` is not a known normalization.
    """
    config = self.config
    is_training = self.is_training
    norm = config.norm_field

    def _log_sigma():
        log.info("Setting sigma={:.3e}".format(config.sigma_init))
        log.info("Setting sigma learnable={}".format(config.learn_sigma))

    def _log_l1():
        log.info("Setting L1={:.3e}".format(config.l1_reg))

    # Each branch logs its settings and defines how one activation is
    # wrapped; the list itself is built once after the dispatch.
    if norm == "batch":
        log.info("Using batch normalization")
        _log_sigma()
        _log_l1()

        def wrap(idx, act_name):
            return get_bn_act(act=get_tf_fn(act_name),
                              is_training=is_training,
                              sigma_init=config.sigma_init,
                              affine=config.norm_affine,
                              l1_reg=config.l1_reg,
                              mask=config.bn_mask[idx],
                              l1_collection=self.l1_collection,
                              learn_sigma=config.learn_sigma,
                              dtype=self.dtype())
    elif norm == "batch_ms":
        log.info("Using mean subtracted batch normalization")
        _log_l1()

        def wrap(idx, act_name):
            return get_bnms_act(act=get_tf_fn(act_name),
                                is_training=is_training,
                                affine=config.norm_affine,
                                l1_reg=config.l1_reg,
                                mask=config.bn_mask[idx],
                                l1_collection=self.l1_collection,
                                dtype=self.dtype())
    elif norm == "layer":
        log.info("Using layer normalization")
        _log_sigma()
        _log_l1()

        def wrap(idx, act_name):
            return get_ln_act(act=get_tf_fn(act_name),
                              sigma_init=config.sigma_init,
                              affine=config.norm_affine,
                              l1_reg=config.l1_reg,
                              l1_collection=self.l1_collection,
                              learn_sigma=config.learn_sigma,
                              dtype=self.dtype())
    elif norm == "layer_ms":
        log.info("Using mean subtracted layer normalization")
        _log_l1()

        def wrap(idx, act_name):
            return get_lnms_act(act=get_tf_fn(act_name),
                                affine=config.norm_affine,
                                l1_reg=config.l1_reg,
                                l1_collection=self.l1_collection,
                                dtype=self.dtype())
    elif norm == "div":
        log.info("Using divisive normalization")
        _log_sigma()
        _log_l1()

        def wrap(idx, act_name):
            return get_dn_act(sum_window=config.sum_window[idx],
                              sup_window=config.sup_window[idx],
                              act=get_tf_fn(act_name),
                              sigma_init=config.sigma_init,
                              affine=config.norm_affine,
                              l1_reg=config.l1_reg,
                              l1_collection=self.l1_collection,
                              learn_sigma=config.learn_sigma,
                              dtype=self.dtype())
    elif norm == "div_ms":
        log.info("Using mean subtracted divisive normalization")
        _log_l1()

        def wrap(idx, act_name):
            return get_dnms_act(sum_window=config.sum_window[idx],
                                act=get_tf_fn(act_name),
                                affine=config.norm_affine,
                                l1_reg=config.l1_reg,
                                l1_collection=self.l1_collection,
                                dtype=self.dtype())
    elif norm == "no" or norm is None:
        log.info("Not using normalization")
        _log_l1()

        def wrap(idx, act_name):
            return get_reg_act(get_tf_fn(act_name),
                               l1_reg=config.l1_reg,
                               l1_collection=self.l1_collection)
    else:
        raise Exception("Unknown normalization \"{}\"".format(
            config.norm_field))

    conv_act_fn = []
    for idx, act_name in enumerate(config.conv_act_fn):
        conv_act_fn.append(wrap(idx, act_name))

    # Resolve pooling and MLP activation functions by name.
    pool_fn = [get_tf_fn(name) for name in config.pool_fn]
    mlp_act_fn = [get_tf_fn(name) for name in config.mlp_act_fn]

    # Convolutional feature extractor.
    features = nn.cnn(x,
                      config.filter_size,
                      strides=config.strides,
                      pool_fn=pool_fn,
                      pool_size=config.pool_size,
                      pool_strides=config.pool_strides,
                      act_fn=conv_act_fn,
                      dtype=self.dtype(),
                      add_bias=True,
                      init_std=config.conv_init_std,
                      init_method=config.conv_init_method,
                      wd=config.wd)

    # Flatten to the MLP input width, then classify.
    flat = tf.reshape(features, [-1, config.mlp_dims[0]])
    logits = nn.mlp(flat,
                    config.mlp_dims,
                    is_training=is_training,
                    act_fn=mlp_act_fn,
                    dtype=self.dtype(),
                    init_std=config.mlp_init_std,
                    init_method=config.mlp_init_method,
                    dropout=config.mlp_dropout,
                    wd=config.wd)
    return logits
def get_model(opt, device='/cpu:0'):
    """Build the attention model graph.

    The graph is unrolled for ``opt['timespan']`` steps.  At every step a
    controller (CNN + LSTM with glimpse iterations + MLP) predicts an
    attention box, an attention CNN / MLP / DCNN stack produces a
    segmentation for the attended patch, and a score MLP produces a
    confidence value.  Box, segmentation and score losses are added to the
    'losses' collection and minimised with a gradient-clipped Adam optimizer.

    Args:
        opt: Mapping of hyper-parameters.  Every key indexed at the top of
            this function must be present (indexing a missing key fails
            before any graph op is created).
        device: Device string handed to ``get_device_fn``.

    Returns:
        dict ``model`` mapping names to placeholders ('x', 'y_gt', 's_gt',
        'phase_train'), outputs ('y_out', 's_out', 'attn_box', 'x_patch',
        ...), losses ('loss', 'box_loss', 'segm_loss', 'conf_loss'),
        statistics and 'train_step'.  Sub-network constructors also receive
        ``model`` and may add their own entries.

    Raises:
        Exception: if ``opt['box_loss_fn']`` or ``opt['segm_loss_fn']`` is
            not one of 'iou', 'wt_iou', 'wt_cov', 'bce'.
    """
    model = {}

    # All options are read eagerly (even the ones not used below) so that an
    # incomplete option dictionary fails immediately.
    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']

    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']

    # New parameters for double attention.
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']
    attn_rnn_hid_dim = opt['attn_rnn_hid_dim']

    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    use_gt_attn = opt['use_gt_attn']
    segm_loss_fn = opt['segm_loss_fn']
    box_loss_fn = opt['box_loss_fn']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    use_attn_rnn = opt['use_attn_rnn']
    use_knob = opt['use_knob']
    knob_base = opt['knob_base']
    knob_decay = opt['knob_decay']
    steps_per_knob_decay = opt['steps_per_knob_decay']
    use_canvas = opt['use_canvas']
    knob_box_offset = opt['knob_box_offset']
    knob_segm_offset = opt['knob_segm_offset']
    knob_use_timescale = opt['knob_use_timescale']
    gt_selector = opt['gt_selector']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']
    downsample_canvas = opt['downsample_canvas']
    pretrain_cnn = opt['pretrain_cnn']
    cnn_share_weights = opt['cnn_share_weights']
    squash_ctrl_params = opt['squash_ctrl_params']
    use_iou_box = opt['use_iou_box']
    clip_gradient = opt['clip_gradient']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    with tf.device(get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth])
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder(
            'float', [None, timespan, inp_height, inp_width])

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan])

        # Whether in training stage.
        phase_train = tf.placeholder('bool')
        phase_train_f = tf.to_float(phase_train)

        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step (kept as a float variable; it feeds the decay
        # schedules below).
        global_step = tf.Variable(0.0)

        ###############################
        # Random input transformation
        ###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        ############################
        # Canvas: external memory
        ############################
        if use_canvas:
            # The canvas is concatenated to the image as one extra channel.
            canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
            ccnn_inp_depth = inp_depth + 1
            acnn_inp_depth = inp_depth + 1
        else:
            ccnn_inp_depth = inp_depth
            acnn_inp_depth = inp_depth

        ############################
        # Controller CNN definition
        ############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers
        if pretrain_cnn:
            # `[:]` copies each dataset into memory, so the file can be
            # closed as soon as the weights have been read.
            with h5py.File(pretrain_cnn, 'r') as h5f:
                ccnn_init_w = [{'w': h5f['cnn_w_{}'.format(ii)][:],
                                'b': h5f['cnn_b_{}'.format(ii)][:]}
                               for ii in xrange(ccnn_nlayers)]
            ccnn_frozen = True
        else:
            ccnn_init_w = None
            ccnn_frozen = None

        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model,
                      init_weights=ccnn_init_w, frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

        ############################
        # Controller RNN definition
        ############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        canvas_dim = inp_height * inp_width / (ccnn_subsample ** 2)

        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        # crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]

        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        crnn_cell = nn.lstm(glimpse_feat_dim, crnn_dim,
                            wd=wd, scope='ctrl_lstm', model=model)

        ############################
        # Glimpse MLP definition
        ############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * (num_glimpse_mlp_layers - 1) + \
            [tf.nn.softmax]
        gmlp_dropout = None
        gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True,
                      dropout_keep=gmlp_dropout,
                      phase_train=phase_train, wd=wd,
                      scope='glimpse_mlp', model=model)

        ############################
        # Controller MLP definition
        ############################
        # The last layer has 9 outputs; columns 0-1, 2-3, 6, 7 and 8 are
        # sliced out in the unrolled loop below.
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        # cmlp_dropout = [1.0 - mlp_dropout_ratio] * num_ctrl_mlp_layers
        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train, wd=wd,
                      scope='ctrl_mlp', model=model)

        ############################
        # Attention CNN definition
        ############################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = attn_cnn_pool
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers
        if cnn_share_weights:
            # Reuse the controller CNN weights registered in `model`.
            ccnn_shared_weights = []
            for ii in xrange(ccnn_nlayers):
                ccnn_shared_weights.append(
                    {'w': model['ctrl_cnn_w_{}'.format(ii)],
                     'b': model['ctrl_cnn_b_{}'.format(ii)]})
        else:
            ccnn_shared_weights = None
        acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act,
                      acnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='attn_cnn', model=model,
                      shared_weights=ccnn_shared_weights)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

        ############################
        # Attention RNN definition
        ############################
        acnn_subsample = np.array(acnn_pool).prod()
        arnn_h = filter_height / acnn_subsample
        arnn_w = filter_width / acnn_subsample

        if use_attn_rnn:
            arnn_dim = attn_rnn_hid_dim
            arnn_inp_dim = arnn_h * arnn_w * acnn_channels[-1]
            arnn_state = [None] * (timespan + 1)
            arnn_g_i = [None] * timespan
            arnn_g_f = [None] * timespan
            arnn_g_o = [None] * timespan
            # Slot -1 holds the initial state, so `arnn_state[tt - 1]`
            # resolves to it at tt == 0.
            arnn_state[-1] = tf.zeros(tf.pack([num_ex, arnn_dim * 2]))
            arnn_cell = nn.lstm(arnn_inp_dim, arnn_dim,
                                wd=wd, scope='attn_lstm')
            amlp_inp_dim = arnn_dim
        else:
            amlp_inp_dim = arnn_h * arnn_w * acnn_channels[-1]

        ############################
        # Attention MLP definition
        ############################
        core_depth = attn_mlp_depth
        core_dim = arnn_h * arnn_w * core_depth
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None
        # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers
        amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout,
                      phase_train=phase_train, wd=wd,
                      scope='attn_mlp', model=model)

        # DCNN [B, RH, RW, MD] => [B, A, A, 1]
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = attn_dcnn_pool
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth
        adcnn_use_bn = [use_bn] * adcnn_nlayers
        adcnn_skip_ch = [0] + acnn_channels[::-1][1:]
        adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool,
                        adcnn_act, use_bn=adcnn_use_bn,
                        skip_ch=adcnn_skip_ch, phase_train=phase_train,
                        wd=wd, model=model, scope='attn_dcnn')
        h_adcnn = [None] * timespan

        ##########################
        # Score MLP definition
        ##########################
        smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid], wd=wd,
                      scope='score_mlp', model=model)
        s_out = [None] * timespan

        ##########################
        # Attention box
        ##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_lg_gamma = [None] * timespan
        attn_gamma = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        iou_soft_box = [None] * timespan
        const_ones = tf.ones(
            tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        attn_box_gamma = [None] * timespan

        #############################
        # Groundtruth attention box
        #############################
        # [B, T, 2]
        attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \
            attn_top_left_gt, attn_bot_right_gt = \
            base.get_gt_attn(y_gt,
                             padding_ratio=attn_box_padding_ratio,
                             center_shift_ratio=0.0)
        # Same boxes, with random padding and centre jitter.
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise))

        ##########################
        # Groundtruth mix
        ##########################
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

        # Add a bias on every entry so there is no duplicate match
        # [1, N]
        iou_bias_eps = 1e-7
        iou_bias = tf.expand_dims(tf.to_float(
            tf.reverse(tf.range(timespan), [True])) * iou_bias_eps, 0)

        # Scale mix ratio on different timesteps.
        gt_knob_time_scale = tf.reshape(
            1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0 *
                         float(knob_use_timescale)), [1, timespan, 1])

        # Mix in groundtruth box.
        global_step_box = tf.maximum(0.0, global_step - knob_box_offset)
        gt_knob_prob_box = tf.train.exponential_decay(
            knob_base, global_step_box, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_box = tf.minimum(
            1.0, gt_knob_prob_box * gt_knob_time_scale)
        gt_knob_box = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box)
        model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0]

        # Mix in groundtruth segmentation.
        global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset)
        gt_knob_prob_segm = tf.train.exponential_decay(
            knob_base, global_step_segm, steps_per_knob_decay, knob_decay,
            staircase=False)
        gt_knob_prob_segm = tf.minimum(
            1.0, gt_knob_prob_segm * gt_knob_time_scale)
        gt_knob_segm = tf.to_float(tf.random_uniform(
            tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm)
        model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0]

        ##########################
        # Segmentation output
        ##########################
        y_out = [None] * timespan
        y_out_lg_gamma = [None] * timespan
        y_out_beta = tf.constant([-5.0])

        ##########################
        # Computation graph
        ##########################
        if not use_canvas:
            # Without a canvas the controller CNN input never changes, so
            # it is evaluated once and reused at every step.
            h_ccnn = ccnn(x)

        for tt in xrange(timespan):
            # Controller CNN [B, H, W, D] => [B, RH1, RW1, RD1]
            if use_canvas:
                ccnn_inp = tf.concat(3, [x, canvas])
                acnn_inp = ccnn_inp
                h_ccnn[tt] = ccnn(ccnn_inp)
                _h_ccnn = h_ccnn[tt]
            else:
                ccnn_inp = x
                acnn_inp = x
                _h_ccnn = h_ccnn

            h_ccnn_last = _h_ccnn[-1]
            # crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
            crnn_inp = tf.reshape(
                h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])

            crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
            crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
            crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
            crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
            h_crnn[tt] = [None] * num_ctrl_rnn_iter

            # The controller LSTM state is reset to zero at every step.
            crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
            # if tt == 0:
            #   crnn_state[tt][-1] = tf.zeros(
            #       tf.pack([num_ex, crnn_dim * 2]))
            # else:
            #   crnn_state[tt][-1] = \
            #       crnn_state[tt - 1][num_ctrl_rnn_iter - 1]

            # The first glimpse weights every location equally.
            crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
            crnn_glimpse_map[tt][0] = tf.ones(
                tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

            for tt2 in xrange(num_ctrl_rnn_iter):
                crnn_glimpse = tf.reduce_sum(
                    crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                crnn_state[tt][tt2], crnn_g_i[tt][tt2], \
                    crnn_g_f[tt][tt2], crnn_g_o[tt][tt2] = \
                    crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
                h_crnn[tt][tt2] = tf.slice(
                    crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim])
                h_gmlp = gmlp(h_crnn[tt][tt2])
                if tt2 < num_ctrl_rnn_iter - 1:
                    crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(
                        h_gmlp[-1], 2)

            ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
            attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1])
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1])

            attn_gamma[tt] = tf.reshape(
                tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1])
            attn_box_gamma[tt] = tf.reshape(tf.exp(
                attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
            y_out_lg_gamma[tt] = tf.reshape(
                y_out_lg_gamma[tt], [-1, 1, 1, 1])

            # Initial filters (predicted)
            filter_y = get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            if use_iou_box:
                _idx_map = get_idx_map(
                    tf.pack([num_ex, inp_height, inp_width]))
                attn_top_left[tt], attn_bot_right[tt] = get_box_coord(
                    attn_ctr[tt], attn_size[tt])
                attn_box[tt] = get_filled_box_idx(
                    _idx_map, attn_top_left[tt], attn_bot_right[tt])
                attn_box[tt] = tf.reshape(
                    attn_box[tt], [-1, 1, inp_height, inp_width])
            else:
                attn_box[tt] = extract_patch(
                    const_ones * attn_box_gamma[tt],
                    filter_y_inv, filter_x_inv, 1)
                attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
                attn_box[tt] = tf.reshape(
                    attn_box[tt], [-1, 1, inp_height, inp_width])

            # Here is the knob kick in GT bbox.
            if use_knob:
                # IOU [B, 1, T]
                # [B, 1, H, W] * [B, T, H, W] = [B, T]
                if use_iou_box:
                    _top_left = tf.expand_dims(attn_top_left[tt], 1)
                    _bot_right = tf.expand_dims(attn_bot_right[tt], 1)
                    iou_soft_box[tt] = f_iou_box(
                        _top_left, _bot_right,
                        attn_top_left_gt, attn_bot_right_gt)
                    iou_soft_box[tt] += iou_bias
                else:
                    iou_soft_box[tt] = \
                        f_inter(attn_box[tt], attn_box_gt) / \
                        f_union(attn_box[tt], attn_box_gt, eps=1e-5)
                grd_match = f_greedy_match(iou_soft_box[tt], grd_match_cum)

                if gt_selector == 'greedy_match':
                    # Add in the cumulative matching to not double count.
                    grd_match_cum += grd_match

                # [B, T, 1]
                grd_match = tf.expand_dims(grd_match, 2)
                attn_ctr_gt_match = tf.reduce_sum(
                    grd_match * attn_ctr_gt_noise, 1)
                attn_size_gt_match = tf.reduce_sum(
                    grd_match * attn_size_gt_noise, 1)

                # During training, replace the predicted box by the matched
                # groundtruth box wherever the sampled knob is on.
                _gt_knob_box = gt_knob_box
                attn_ctr[tt] = \
                    phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_ctr_gt_match + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_ctr[tt]
                attn_size[tt] = \
                    phase_train_f * _gt_knob_box[:, tt, 0: 1] * \
                    attn_size_gt_match + \
                    (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \
                    attn_size[tt]

            # Recompute box corners and filters from the (possibly mixed)
            # centre and size.
            attn_top_left[tt], attn_bot_right[tt] = get_box_coord(
                attn_ctr[tt], attn_size[tt])

            filter_y = get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attended patch [B, A, A, D]
            x_patch[tt] = attn_gamma[tt] * extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]

            if use_attn_rnn:
                # RNN [B, T, R2]
                arnn_inp = tf.reshape(h_acnn_last[tt], [-1, arnn_inp_dim])
                arnn_state[tt], arnn_g_i[tt], arnn_g_f[tt], \
                    arnn_g_o[tt] = arnn_cell(arnn_inp, arnn_state[tt - 1])

            # Scoring network
            s_out[tt] = smlp(h_crnn[tt][-1])[-1]

            # Dense segmentation network [B, R] => [B, M]
            if use_attn_rnn:
                h_arnn = tf.slice(
                    arnn_state[tt], [0, arnn_dim], [-1, arnn_dim])
                amlp_inp = h_arnn
            else:
                amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core = tf.reshape(
                h_core, [-1, arnn_h, arnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core, skip=skip)

            # Output
            y_out[tt] = extract_patch(
                h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(
                y_out[tt], [-1, 1, inp_height, inp_width])

            # Here is the knob kick in GT segmentations at this timestep.
            # [B, N, 1, 1]
            if use_canvas:
                if use_knob:
                    _gt_knob_segm = tf.expand_dims(
                        tf.expand_dims(gt_knob_segm[:, tt, 0: 1], 2), 3)
                    # [B, N, 1, 1]
                    grd_match = tf.expand_dims(grd_match, 3)
                    _y_out = tf.expand_dims(tf.reduce_sum(
                        grd_match * y_gt, 1), 3)
                    # Add independent uniform noise to groundtruth.
                    _noise = tf.random_uniform(
                        tf.pack([num_ex, inp_height, inp_width, 1]),
                        0, gt_segm_noise)
                    _y_out = _y_out - _y_out * _noise
                    _y_out = phase_train_f * _gt_knob_segm * _y_out + \
                        (1 - phase_train_f * _gt_knob_segm) * \
                        tf.reshape(y_out[tt],
                                   [-1, inp_height, inp_width, 1])
                else:
                    _y_out = tf.reshape(y_out[tt],
                                        [-1, inp_height, inp_width, 1])
                # The canvas is updated without back-propagating into the
                # segmentation that was written to it.
                canvas += tf.stop_gradient(_y_out)

        #########################
        # Model outputs
        #########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        x_patch = tf.concat(1, [tf.expand_dims(x_patch[tt], 1)
                                for tt in xrange(timespan)])
        model['x_patch'] = x_patch

        #########################
        # Loss function
        #########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

        ############################
        # Box loss
        ############################
        if use_knob:
            iou_soft_box = tf.concat(1, [tf.expand_dims(iou_soft_box[tt], 1)
                                         for tt in xrange(timespan)])
        else:
            iou_soft_box = f_iou(attn_box, attn_box_gt,
                                 timespan, pairwise=True)
        # Keep a handle on the pairwise IoU: `iou_soft_box` is overwritten
        # by its matched batch mean below, but the 'wt_cov' loss needs the
        # pairwise tensor (as the segmentation branch passes to the same
        # helper).
        iou_soft_box_pairwise = iou_soft_box

        model['iou_soft_box'] = iou_soft_box
        model['attn_box_gt'] = attn_box_gt
        match_box = f_segm_match(iou_soft_box, s_gt)
        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(
            match_sum_box, reduction_indices=[1])
        # Avoid dividing by zero when nothing is matched.
        match_count_box = tf.maximum(1.0, match_count_box)
        iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(
            tf.reduce_sum(iou_soft_box_mask, [1]) /
            match_count_box) / num_ex_f
        gt_wt_box = f_coverage_weight(attn_box_gt)
        wt_iou_soft_box = tf.reduce_sum(tf.reduce_sum(
            iou_soft_box_mask * gt_wt_box, [1]) /
            match_count_box) / num_ex_f
        if box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_iou':
            box_loss = -wt_iou_soft_box
        elif box_loss_fn == 'wt_cov':
            box_loss = -f_weighted_coverage(
                iou_soft_box_pairwise, attn_box_gt)
        elif box_loss_fn == 'bce':
            box_loss = f_match_bce(
                attn_box, attn_box_gt, match_box, timespan)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

        ##############################
        # Segmentation loss
        ##############################
        # IoU (soft)
        iou_soft = f_iou(y_out, y_gt, timespan, pairwise=True)
        match = f_segm_match(iou_soft, s_gt)
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = f_weighted_coverage(iou_soft, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = f_unweighted_coverage(iou_soft, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # IOU (soft)
        iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(tf.reduce_sum(iou_soft_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_soft'] = iou_soft
        gt_wt = f_coverage_weight(y_gt)
        wt_iou_soft = tf.reduce_sum(
            tf.reduce_sum(iou_soft_mask * gt_wt, [1]) /
            match_count) / num_ex_f
        model['wt_iou_soft'] = wt_iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_iou':
            segm_loss = -wt_iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

        ####################
        # Score loss
        ####################
        conf_loss = f_conf_loss(s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        tf.add_to_collection('losses', loss_mix_ratio * conf_loss)

        ####################
        # Total loss
        ####################
        total_loss = tf.add_n(tf.get_collection(
            'losses'), name='total_loss')
        model['loss'] = total_loss

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss,
                                         global_step=global_step)
        model['train_step'] = train_step

        ####################
        # Statistics
        ####################
        # [B, M, N] * [B, M, N] => [B] * [B] => [1]
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        # [B, T]
        iou_hard_mask = tf.reduce_sum(iou_hard * match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_hard'] = iou_hard
        wt_iou_hard = tf.reduce_sum(
            tf.reduce_sum(iou_hard_mask * gt_wt, [1]) /
            match_count) / num_ex_f
        model['wt_iou_hard'] = wt_iou_hard

        dice = f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(dice * match, [1, 2]) /
                             match_count) / num_ex_f
        model['dice'] = dice

        model['count_acc'] = f_count_acc(s_out, s_gt)
        model['dic'] = f_dic(s_out, s_gt, abs=False)
        model['dic_abs'] = f_dic(s_out, s_gt, abs=True)

        ################################
        # Controller output statistics
        ################################
        # Stack the per-step tensors along a new axis 1.
        attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_top_left])
        attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1)
                                 for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                  for tmp in attn_size])
        attn_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_lg_gamma])
        attn_box_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                          for tmp in attn_box_lg_gamma])
        y_out_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in y_out_lg_gamma])
        attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / \
            num_ex_f / timespan
        attn_box_lg_gamma_mean = tf.reduce_sum(
            attn_box_lg_gamma) / num_ex_f / timespan
        y_out_lg_gamma_mean = tf.reduce_sum(
            y_out_lg_gamma) / num_ex_f / timespan
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_lg_gamma_mean'] = attn_lg_gamma_mean
        model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean
        model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean

    return model
def get_model(opt, device='/cpu:0'):
    """Build the attention (patch segmentation) model graph.

    At each of ``opt['timespan']`` steps the attention box is taken from the
    noisy groundtruth boxes (in index order when ``opt['fixed_order']`` is
    set, otherwise selected through the 'order' placeholder), the attended
    patch of image + canvas goes through the attention CNN / MLP / DCNN
    stack, and the groundtruth segmentation (with uniform noise) is added to
    the canvas.  The segmentation loss is added to the 'losses' collection
    and minimised with a gradient-clipped Adam optimizer.

    Args:
        opt: Mapping of hyper-parameters.  Every key indexed at the top of
            this function must be present.
        device: Device string handed to ``base.get_device_fn``.

    Returns:
        dict ``model`` mapping names to placeholders ('x', 'y_gt', 's_gt',
        'phase_train', 'order'), outputs ('y_out', 'x_patch', 'attn_*'),
        'loss', 'segm_loss', statistics and 'train_step'.  Sub-network
        constructors also receive ``model`` and may add their own entries.

    Raises:
        Exception: if ``opt['segm_loss_fn']`` is not one of 'iou',
            'wt_iou', 'wt_cov', 'bce'.
    """
    model = {}

    # All options are read eagerly (even the ones not used below) so that an
    # incomplete option dictionary fails immediately.
    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']

    attn_cnn_filter_size = opt['attn_cnn_filter_size']
    attn_cnn_depth = opt['attn_cnn_depth']
    attn_cnn_pool = opt['attn_cnn_pool']
    attn_dcnn_filter_size = opt['attn_dcnn_filter_size']
    attn_dcnn_depth = opt['attn_dcnn_depth']
    attn_dcnn_pool = opt['attn_dcnn_pool']

    mlp_dropout_ratio = opt['mlp_dropout']

    num_attn_mlp_layers = opt['num_attn_mlp_layers']
    attn_mlp_depth = opt['attn_mlp_depth']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']

    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    segm_loss_fn = opt['segm_loss_fn']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    gt_box_ctr_noise = opt['gt_box_ctr_noise']
    gt_box_pad_noise = opt['gt_box_pad_noise']
    gt_segm_noise = opt['gt_segm_noise']
    clip_gradient = opt['clip_gradient']
    fixed_order = opt['fixed_order']

    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth], name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder(
            'float', [None, timespan, inp_height, inp_width], name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Order in which we feed in the samples
        order = tf.placeholder('int32', [None, timespan], name='order')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        phase_train_f = tf.to_float(phase_train)

        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train
        model['order'] = order

        # Global step (kept as a float variable; it feeds the learning-rate
        # decay schedule below).
        global_step = tf.Variable(0.0, name='global_step')

        ###############################
        # Random input transformation
        ###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        ############################
        # Canvas: external memory
        ############################
        # The canvas is concatenated to the image as one extra channel.
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        acnn_inp_depth = inp_depth + 1

        ###########################
        # Attention CNN definition
        ###########################
        acnn_filters = attn_cnn_filter_size
        acnn_nlayers = len(acnn_filters)
        acnn_channels = [acnn_inp_depth] + attn_cnn_depth
        acnn_pool = attn_cnn_pool
        acnn_act = [tf.nn.relu] * acnn_nlayers
        acnn_use_bn = [use_bn] * acnn_nlayers
        acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act,
                      acnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='attn_cnn', model=model)

        x_patch = [None] * timespan
        h_acnn = [None] * timespan
        h_acnn_last = [None] * timespan

        ############################
        # Attention MLP definition
        ############################
        acnn_subsample = np.array(acnn_pool).prod()
        acnn_h = filter_height / acnn_subsample
        acnn_w = filter_width / acnn_subsample
        core_depth = attn_mlp_depth
        core_dim = acnn_h * acnn_w * core_depth
        amlp_inp_dim = acnn_h * acnn_w * acnn_channels[-1]
        amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers
        amlp_act = [tf.nn.relu] * num_attn_mlp_layers
        amlp_dropout = None
        # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers
        amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout,
                      phase_train=phase_train, wd=wd,
                      scope='attn_mlp', model=model)

        #############################
        # Attention DCNN definition
        #############################
        adcnn_filters = attn_dcnn_filter_size
        adcnn_nlayers = len(adcnn_filters)
        adcnn_unpool = attn_dcnn_pool
        adcnn_act = [tf.nn.relu] * adcnn_nlayers
        adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth
        adcnn_bn_nlayers = adcnn_nlayers
        # adcnn_bn_nlayers = adcnn_nlayers - 1
        adcnn_use_bn = [use_bn] * adcnn_bn_nlayers + \
            [False] * (adcnn_nlayers - adcnn_bn_nlayers)
        adcnn_skip_ch = [0] + acnn_channels[::-1][1:]
        adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool,
                        adcnn_act, use_bn=adcnn_use_bn,
                        skip_ch=adcnn_skip_ch, phase_train=phase_train,
                        wd=wd, model=model, scope='attn_dcnn')
        h_adcnn = [None] * timespan

        ##########################
        # Attention box
        ##########################
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan

        #############################
        # Groundtruth attention box
        #############################
        # Groundtruth boxes with random padding and centre jitter.
        attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \
            attn_box_gt_noise, \
            attn_top_left_gt_noise, attn_bot_right_gt_noise = \
            base.get_gt_attn(y_gt,
                             padding_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 1]),
                                 attn_box_padding_ratio - gt_box_pad_noise,
                                 attn_box_padding_ratio + gt_box_pad_noise),
                             center_shift_ratio=tf.random_uniform(
                                 tf.pack([num_ex, timespan, 2]),
                                 -gt_box_ctr_noise, gt_box_ctr_noise),
                             min_padding=25.0)

        ##########################
        # Segmentation output
        ##########################
        y_out = [None] * timespan
        y_out_lg_gamma = tf.constant([2.0])
        y_out_beta = tf.constant([-5.0])

        ##########################
        # Computation graph
        ##########################
        for tt in xrange(timespan):
            # Get a new greedy match based on order.
            if fixed_order:
                attn_ctr[tt] = attn_ctr_gt_noise[:, tt, :]
                attn_size[tt] = attn_size_gt_noise[:, tt, :]
            else:
                mask = _get_idx_mask(order[:, tt], timespan)
                # [B, T, 1]
                mask = tf.expand_dims(mask, 2)
                attn_ctr[tt] = tf.reduce_sum(mask * attn_ctr_gt_noise, 1)
                attn_size[tt] = tf.reduce_sum(mask * attn_size_gt_noise, 1)

            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            # [B, H, H']
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0], 0.0,
                inp_height, filter_height)

            # [B, W, W']
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1], 0.0,
                inp_width, filter_width)

            # [B, H', H]
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])

            # [B, W', W]
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attended patch [B, A, A, D]
            acnn_inp = tf.concat(3, [x, canvas])
            x_patch[tt] = base.extract_patch(
                acnn_inp, filter_y, filter_x, acnn_inp_depth)

            # CNN [B, A, A, D] => [B, RH2, RW2, RD2]
            h_acnn[tt] = acnn(x_patch[tt])
            h_acnn_last[tt] = h_acnn[tt][-1]
            amlp_inp = h_acnn_last[tt]
            amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim])
            h_core = amlp(amlp_inp)[-1]
            h_core = tf.reshape(
                h_core, [-1, acnn_h, acnn_w, attn_mlp_depth])

            # DCNN
            skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]]
            h_adcnn[tt] = adcnn(h_core, skip=skip)

            # Output
            y_out[tt] = base.extract_patch(
                h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1)
            y_out[tt] = tf.exp(y_out_lg_gamma) * y_out[tt] + y_out_beta
            y_out[tt] = tf.sigmoid(y_out[tt])
            y_out[tt] = tf.reshape(
                y_out[tt], [-1, 1, inp_height, inp_width])

            # Canvas: always written from the groundtruth segmentation, not
            # from the prediction.
            if fixed_order:
                _y_out = y_gt[:, tt, :, :]
            else:
                mask = tf.expand_dims(mask, 3)
                _y_out = tf.reduce_sum(mask * y_gt, 1)
            _y_out = tf.expand_dims(_y_out, 3)
            # Add independent uniform noise to groundtruth.
            _noise = tf.random_uniform(
                tf.pack([num_ex, inp_height, inp_width, 1]),
                0, gt_segm_noise)
            _y_out = _y_out - _y_out * _noise
            canvas += _y_out

        #########################
        # Model outputs
        #########################
        y_out = tf.concat(1, y_out)
        model['y_out'] = y_out
        x_patch = tf.concat(1, [tf.expand_dims(x_patch[tt], 1)
                                for tt in xrange(timespan)])
        model['x_patch'] = x_patch

        # Stack the per-step tensors along a new axis 1.
        attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1)
                                      for tmp in attn_top_left])
        attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1)
                                       for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1)
                                 for tmp in attn_ctr])
        attn_size = tf.concat(1, [tf.expand_dims(tmp, 1)
                                  for tmp in attn_size])
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right

        #########################
        # Loss function
        #########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        max_num_obj = tf.to_float(y_gt_shape[1])

        ##############################
        # Segmentation loss
        ##############################
        # Matching
        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        iou_soft_pairwise = base.f_iou(
            y_out, y_gt, timespan, pairwise=True)
        real_match = base.f_segm_match(iou_soft_pairwise, s_gt)
        if fixed_order:
            iou_soft = base.f_iou(y_out, y_gt, pairwise=False)
            match = identity_match
        else:
            iou_soft = iou_soft_pairwise
            match = real_match
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        # Avoid dividing by zero when nothing is matched.
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = base.f_weighted_coverage(iou_soft_pairwise, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = base.f_unweighted_coverage(
            iou_soft_pairwise, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # IOU (soft)
        if fixed_order:
            iou_soft_mask = iou_soft
        else:
            iou_soft_mask = tf.reduce_sum(iou_soft * match, [1])
        iou_soft = tf.reduce_sum(iou_soft_mask, [1])
        iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f
        model['iou_soft'] = iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_iou':
            # Previously `wt_iou_soft` was referenced here without ever
            # being defined, so this option raised NameError.  It is built
            # only in this branch so other configurations keep an identical
            # graph.
            # NOTE(review): assumes `base` exposes `f_coverage_weight` and
            # that its output lines up with `iou_soft_mask` -- confirm.
            gt_wt = base.f_coverage_weight(y_gt)
            wt_iou_soft = tf.reduce_sum(
                tf.reduce_sum(iou_soft_mask * gt_wt, [1]) /
                match_count) / num_ex_f
            segm_loss = -wt_iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = base.f_match_bce(y_out, y_gt, match, timespan)
        else:
            raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss
        segm_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', segm_loss_coeff * segm_loss)

        ####################
        # Total loss
        ####################
        total_loss = tf.add_n(tf.get_collection('losses'),
                              name='total_loss')
        model['loss'] = total_loss

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7
        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss,
                                         global_step=global_step)
        model['train_step'] = train_step

        ####################
        # Statistics
        ####################
        # Here statistics (hard measures) is always using matching.
        y_out_hard = tf.to_float(y_out > 0.5)
        iou_hard = base.f_iou(y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = base.f_weighted_coverage(iou_hard, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = base.f_unweighted_coverage(iou_hard, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) /
                                 match_count) / num_ex_f
        model['iou_hard'] = iou_hard

        dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(tf.reduce_sum(
            dice * real_match, reduction_indices=[1, 2]) / match_count) / \
            num_ex_f
        model['dice'] = dice

    return model
def get_model(opt, device='/cpu:0'):
    """Build the TensorFlow graph of the box model.

    At each of `timespan` steps a controller CNN reads the input image
    concatenated with a canvas channel, a controller LSTM ("dense" or
    glimpse-based "attn" input) summarises the CNN features, and MLP heads
    produce attention box parameters and a confidence score.  The canvas is
    then updated from the groundtruth segmentation.

    Args:
        opt: Mapping of hyper-parameters.  Keys read here: timespan,
            inp_height, inp_width, inp_depth, padding, filter_height,
            filter_width, ctrl_cnn_filter_size, ctrl_cnn_depth,
            ctrl_cnn_pool, ctrl_rnn_hid_dim, num_ctrl_mlp_layers,
            ctrl_mlp_dim, attn_box_padding_ratio, weight_decay, use_bn,
            box_loss_fn, base_learn_rate, learn_rate_decay,
            steps_per_learn_rate_decay, pretrain_cnn, squash_ctrl_params,
            clip_gradient, fixed_order, ctrl_rnn_inp_struct,
            num_ctrl_rnn_iter, num_glimpse_mlp_layers, rnd_hflip, rnd_vflip,
            rnd_transpose, rnd_colour.
        device: Device string handed to `base.get_device_fn`.

    Returns:
        dict mapping names to graph nodes.  Keys written by this function:
        x, y_gt, s_gt, phase_train, x_trans, y_gt_trans, s_out, attn_box,
        attn_top_left, attn_bot_right, attn_ctr, attn_size,
        attn_ctr_norm_gt, attn_lg_size_gt, attn_top_left_gt,
        attn_bot_right_gt, attn_box_gt, attn_ctr_norm, attn_lg_size,
        match_box, box_loss, box_loss_coeff, conf_loss, loss, learn_rate,
        train_step, and ctrl_rnn_glimpse_map when ctrl_rnn_inp_struct is
        'attn'.  The dict is also passed as `model=` to several `nn.*` and
        `base.*` helpers.

    Raises:
        Exception: if `ctrl_rnn_inp_struct` is neither 'dense' nor 'attn',
            or if `box_loss_fn` is not one of 'mse', 'huber', 'iou',
            'wt_iou', 'wt_cov', 'bce'.
    """
    model = {}
    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']
    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']
    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']
    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    box_loss_fn = opt['box_loss_fn']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    pretrain_cnn = opt['pretrain_cnn']
    squash_ctrl_params = opt['squash_ctrl_params']
    clip_gradient = opt['clip_gradient']
    fixed_order = opt['fixed_order']
    ctrl_rnn_inp_struct = opt['ctrl_rnn_inp_struct']  # dense or attn
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']
    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder(
            'float', [None, inp_height, inp_width, inp_depth], name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder(
            'float', [None, timespan, inp_height, inp_width], name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        # Not read again in this function; kept so the set of ops created
        # in the graph stays the same as before.
        phase_train_f = tf.to_float(phase_train)
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0, name='global_step')

        ###############################
        # Random input transformation
        ###############################
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        ############################
        # Canvas: external memory
        ############################
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        # The controller CNN sees the image plus one canvas channel.
        ccnn_inp_depth = inp_depth + 1

        ############################
        # Controller CNN definition
        ############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers

        if pretrain_cnn:
            log.info('Loading pretrained weights from {}'.format(pretrain_cnn))
            h5f = h5py.File(pretrain_cnn, 'r')
            try:
                # Count how many layers the weight file provides.
                # Assuming acnn_nlayers is smaller than ccnn_nlayers.
                acnn_nlayers = 0
                for ii in xrange(ccnn_nlayers):
                    if 'attn_cnn_w_{}'.format(ii) in h5f:
                        log.info('Loading attn_cnn_w_{}'.format(ii))
                        log.info('Loading attn_cnn_b_{}'.format(ii))
                        acnn_nlayers += 1
                # `[:]` reads each dataset into memory, so nothing below
                # needs the file once this block finishes.
                ccnn_init_w = [{'w': h5f['attn_cnn_w_{}'.format(ii)][:],
                                'b': h5f['attn_cnn_b_{}'.format(ii)][:]}
                               for ii in xrange(acnn_nlayers)]
                for ii in xrange(acnn_nlayers):
                    for tt in xrange(timespan):
                        for w in ['beta', 'gamma']:
                            ccnn_init_w[ii]['{}_{}'.format(w, tt)] = h5f[
                                'attn_cnn_{}_{}_{}'.format(ii, tt, w)][:]
            finally:
                # Previously the file handle was never released.
                h5f.close()
            # Layers with pretrained weights are frozen; the remaining
            # layers get no initial weights and stay trainable.
            ccnn_frozen = [True] * acnn_nlayers
            num_fresh = ccnn_nlayers - acnn_nlayers
            ccnn_init_w += [None] * num_fresh
            ccnn_frozen += [False] * num_fresh
        else:
            ccnn_init_w = None
            ccnn_frozen = None

        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model,
                      init_weights=ccnn_init_w, frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

        ############################
        # Controller RNN definition
        ############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        if ctrl_rnn_inp_struct == 'dense':
            crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        elif ctrl_rnn_inp_struct == 'attn':
            crnn_inp_dim = glimpse_feat_dim
        else:
            # Fail with a clear message instead of a NameError on
            # crnn_inp_dim a few lines further down.
            raise Exception('Unknown ctrl_rnn_inp_struct: {}'.format(
                ctrl_rnn_inp_struct))
        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        # Slot -1 holds the initial state, so crnn_state[tt - 1] resolves to
        # it at tt == 0.
        crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm',
                            model=model)

        ############################
        # Glimpse MLP definition
        ############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None
        gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True,
                      dropout_keep=gmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='glimpse_mlp',
                      model=model)

        ############################
        # Controller MLP definition
        ############################
        # The head emits 9 values per example; this function reads columns
        # 0-1 (centre), 2-3 (log size) and 7 (box log gamma).
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout,
                      phase_train=phase_train, wd=wd, scope='ctrl_mlp',
                      model=model)

        ##########################
        # Score MLP definition
        ##########################
        smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid], wd=wd, scope='score_mlp')
        s_out = [None] * timespan

        ##########################
        # Attention box
        ##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_box_gamma = [None] * timespan
        const_ones = tf.ones(
            tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        iou_soft_box = [None] * timespan

        #############################
        # Groundtruth attention box
        #############################
        attn_top_left_gt, attn_bot_right_gt, attn_box_gt = base.get_gt_box(
            y_gt, padding_ratio=attn_box_padding_ratio,
            center_shift_ratio=0.0)
        attn_ctr_gt, attn_size_gt = base.get_box_ctr_size(
            attn_top_left_gt, attn_bot_right_gt)
        attn_ctr_norm_gt = base.get_normalized_center(
            attn_ctr_gt, inp_height, inp_width)
        attn_lg_size_gt = base.get_normalized_size(
            attn_size_gt, inp_height, inp_width)

        ##########################
        # Groundtruth mix
        ##########################
        # NOTE(review): grd_match_cum is never updated in this function, so
        # every greedy match below sees an all-zero cumulative mask --
        # confirm this is intended.
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

        ##########################
        # Computation graph
        ##########################
        for tt in xrange(timespan):
            # Controller CNN
            ccnn_inp = tf.concat(3, [x, canvas])
            h_ccnn[tt] = ccnn(ccnn_inp)
            _h_ccnn = h_ccnn[tt]
            h_ccnn_last = _h_ccnn[-1]

            # Controller RNN [B, R1]
            if ctrl_rnn_inp_struct == 'dense':
                crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
                crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[tt] = \
                    crnn_cell(crnn_inp, crnn_state[tt - 1])
                h_crnn[tt] = tf.slice(
                    crnn_state[tt], [0, crnn_dim], [-1, crnn_dim])
                ctrl_out = cmlp(h_crnn[tt])[-1]
            elif ctrl_rnn_inp_struct == 'attn':
                crnn_inp = tf.reshape(
                    h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])
                crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
                crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
                h_crnn[tt] = [None] * num_ctrl_rnn_iter
                # The inner RNN restarts from a zero state at every step.
                crnn_state[tt][-1] = tf.zeros(
                    tf.pack([num_ex, crnn_dim * 2]))
                crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
                # The first glimpse weights every location equally.
                crnn_glimpse_map[tt][0] = tf.ones(
                    tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

                # Inner glimpse RNN
                for tt2 in xrange(num_ctrl_rnn_iter):
                    crnn_glimpse = tf.reduce_sum(
                        crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                    crnn_state[tt][tt2], crnn_g_i[tt][tt2], \
                        crnn_g_f[tt][tt2], crnn_g_o[tt][tt2] = \
                        crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
                    h_crnn[tt][tt2] = tf.slice(
                        crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim])
                    h_gmlp = gmlp(h_crnn[tt][tt2])
                    if tt2 < num_ctrl_rnn_iter - 1:
                        crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(
                            h_gmlp[-1], 2)
                ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            attn_box_gamma[tt] = tf.reshape(
                tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1])
            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            # Initial filters (predicted)
            filter_y = base.get_gaussian_filter(
                attn_ctr[tt][:, 0], attn_size[tt][:, 0],
                attn_lg_var[tt][:, 0], inp_height, filter_height)
            filter_x = base.get_gaussian_filter(
                attn_ctr[tt][:, 1], attn_size[tt][:, 1],
                attn_lg_var[tt][:, 1], inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            attn_box[tt] = base.extract_patch(
                const_ones * attn_box_gamma[tt],
                filter_y_inv, filter_x_inv, 1)
            attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
            attn_box[tt] = tf.reshape(
                attn_box[tt], [-1, 1, inp_height, inp_width])

            # Pick the groundtruth mask that is written to the canvas.
            if fixed_order:
                _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
            else:
                iou_soft_box[tt] = base.f_inter(
                    attn_box[tt], attn_box_gt) / \
                    base.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
                grd_match = base.f_greedy_match(
                    iou_soft_box[tt], grd_match_cum)
                grd_match = tf.expand_dims(tf.expand_dims(grd_match, 2), 3)
                _y_out = tf.expand_dims(
                    tf.reduce_sum(grd_match * y_gt, 1), 3)

            # Add independent uniform noise to groundtruth.
            _noise = tf.random_uniform(
                tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3)
            _y_out = _y_out - _y_out * _noise
            canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))
            # canvas += tf.stop_gradient(_y_out)

            # Scoring network
            if ctrl_rnn_inp_struct == 'dense':
                s_out[tt] = smlp(h_crnn[tt])[-1]
            elif ctrl_rnn_inp_struct == 'attn':
                s_out[tt] = smlp(h_crnn[tt][-1])[-1]

        #########################
        # Model outputs
        #########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        attn_top_left = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left])
        attn_bot_right = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right])
        attn_ctr = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
        attn_size = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_ctr_norm_gt'] = attn_ctr_norm_gt
        model['attn_lg_size_gt'] = attn_lg_size_gt
        model['attn_top_left_gt'] = attn_top_left_gt
        model['attn_bot_right_gt'] = attn_bot_right_gt
        model['attn_box_gt'] = attn_box_gt
        attn_ctr_norm = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr_norm])
        attn_lg_size = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_size])
        model['attn_ctr_norm'] = attn_ctr_norm
        model['attn_lg_size'] = attn_lg_size
        attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
        attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

        #########################
        # Loss function
        #########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        # Not read again in this function; kept so the set of ops created
        # in the graph stays the same as before.
        max_num_obj = tf.to_float(y_gt_shape[1])

        ############################
        # Box loss
        ############################
        if fixed_order:
            # [B, T] for fixed order.
            iou_soft_box = base.f_iou(attn_box, attn_box_gt, pairwise=False)
        else:
            # [B, T, T] for matching.
            iou_soft_box = tf.concat(
                1, [tf.expand_dims(iou_soft_box[tt], 1)
                    for tt in xrange(timespan)])

        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        if fixed_order:
            match_box = identity_match
        else:
            match_box = base.f_segm_match(iou_soft_box, s_gt)

        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(
            match_sum_box, reduction_indices=[1])
        # Clamp to at least 1 before it is used as a divisor.
        match_count_box = tf.maximum(1.0, match_count_box)

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_box_mask = iou_soft_box
        else:
            iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
        iou_soft_box = tf.reduce_sum(
            iou_soft_box / match_count_box) / num_ex_f

        # One single if/elif chain: the former second `if` sent 'mse' and
        # 'huber' into the final `else` and raised after computing the loss.
        if box_loss_fn == 'mse':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan,
                base.f_squared_err, model=model)
        elif box_loss_fn == 'huber':
            box_loss = base.f_match_loss(
                attn_params, attn_params_gt, match_box, timespan,
                base.f_huber)
        elif box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_iou':
            # Previously referenced an undefined `wt_iou_soft_box`.
            # NOTE(review): assumes `base` exposes f_coverage_weight like
            # the other f_* helpers used here -- confirm.
            gt_wt_box = base.f_coverage_weight(attn_box_gt)
            wt_iou_soft_box = tf.reduce_sum(
                tf.reduce_sum(iou_soft_box_mask * gt_wt_box, [1]) /
                match_count_box) / num_ex_f
            box_loss = -wt_iou_soft_box
        elif box_loss_fn == 'wt_cov':
            # Previously referenced an undefined `box_map_gt`.
            # NOTE(review): iou_soft_box is already reduced to the batch
            # mean at this point -- confirm f_weighted_coverage accepts it.
            box_loss = -base.f_weighted_coverage(iou_soft_box, attn_box_gt)
        elif box_loss_fn == 'bce':
            # Previously referenced undefined `box_map` / `box_map_gt`.
            box_loss = base.f_match_loss(
                attn_box, attn_box_gt, match_box, timespan, base.f_bce)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        model['box_loss_coeff'] = box_loss_coeff
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

        ####################
        # Score loss
        ####################
        conf_loss = base.f_conf_loss(
            s_out, match_box, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        conf_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', conf_loss_coeff * conf_loss)

        ####################
        # Total loss
        ####################
        # Sums everything in the 'losses' collection, not only the two
        # terms added above.
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7
        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

        ####################
        # Glimpse
        ####################
        # T * T2 * [B, H' * W'] => [B, T, T2, H', W']
        if ctrl_rnn_inp_struct == 'attn':
            crnn_glimpse_map = tf.concat(
                1, [tf.expand_dims(tf.concat(
                    1, [tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
                        for tt2 in xrange(num_ctrl_rnn_iter)]), 1)
                    for tt in xrange(timespan)])
            crnn_glimpse_map = tf.reshape(
                crnn_glimpse_map,
                [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w])
            model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

    return model
def get_model(opt, device='/cpu:0'): """The attention model""" model = {} timespan = opt['timespan'] inp_height = opt['inp_height'] inp_width = opt['inp_width'] inp_depth = opt['inp_depth'] padding = opt['padding'] filter_height = opt['filter_height'] filter_width = opt['filter_width'] ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size'] ctrl_cnn_depth = opt['ctrl_cnn_depth'] ctrl_cnn_pool = opt['ctrl_cnn_pool'] ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim'] num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers'] ctrl_mlp_dim = opt['ctrl_mlp_dim'] # New parameters for double attention. num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter'] num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers'] attn_cnn_filter_size = opt['attn_cnn_filter_size'] attn_cnn_depth = opt['attn_cnn_depth'] attn_cnn_pool = opt['attn_cnn_pool'] attn_dcnn_filter_size = opt['attn_dcnn_filter_size'] attn_dcnn_depth = opt['attn_dcnn_depth'] attn_dcnn_pool = opt['attn_dcnn_pool'] attn_rnn_hid_dim = opt['attn_rnn_hid_dim'] mlp_dropout_ratio = opt['mlp_dropout'] num_attn_mlp_layers = opt['num_attn_mlp_layers'] attn_mlp_depth = opt['attn_mlp_depth'] attn_box_padding_ratio = opt['attn_box_padding_ratio'] wd = opt['weight_decay'] use_bn = opt['use_bn'] use_gt_attn = opt['use_gt_attn'] segm_loss_fn = opt['segm_loss_fn'] box_loss_fn = opt['box_loss_fn'] loss_mix_ratio = opt['loss_mix_ratio'] base_learn_rate = opt['base_learn_rate'] learn_rate_decay = opt['learn_rate_decay'] steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay'] use_attn_rnn = opt['use_attn_rnn'] use_knob = opt['use_knob'] knob_base = opt['knob_base'] knob_decay = opt['knob_decay'] steps_per_knob_decay = opt['steps_per_knob_decay'] use_canvas = opt['use_canvas'] knob_box_offset = opt['knob_box_offset'] knob_segm_offset = opt['knob_segm_offset'] knob_use_timescale = opt['knob_use_timescale'] gt_selector = opt['gt_selector'] gt_box_ctr_noise = opt['gt_box_ctr_noise'] gt_box_pad_noise = opt['gt_box_pad_noise'] gt_segm_noise = opt['gt_segm_noise'] 
downsample_canvas = opt['downsample_canvas'] pretrain_cnn = opt['pretrain_cnn'] cnn_share_weights = opt['cnn_share_weights'] squash_ctrl_params = opt['squash_ctrl_params'] use_iou_box = opt['use_iou_box'] clip_gradient = opt['clip_gradient'] rnd_hflip = opt['rnd_hflip'] rnd_vflip = opt['rnd_vflip'] rnd_transpose = opt['rnd_transpose'] rnd_colour = opt['rnd_colour'] ############################ # Input definition ############################ with tf.device(get_device_fn(device)): # Input image, [B, H, W, D] x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth]) x_shape = tf.shape(x) num_ex = x_shape[0] # Groundtruth segmentation, [B, T, H, W] y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width]) # Groundtruth confidence score, [B, T] s_gt = tf.placeholder('float', [None, timespan]) # Whether in training stage. phase_train = tf.placeholder('bool') phase_train_f = tf.to_float(phase_train) model['x'] = x model['y_gt'] = y_gt model['s_gt'] = s_gt model['phase_train'] = phase_train # Global step global_step = tf.Variable(0.0) ############################### # Random input transformation ############################### x, y_gt = img.random_transformation(x, y_gt, padding, phase_train, rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip, rnd_transpose=rnd_transpose, rnd_colour=rnd_colour) model['x_trans'] = x model['y_gt_trans'] = y_gt ############################ # Canvas: external memory ############################ if use_canvas: canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1])) ccnn_inp_depth = inp_depth + 1 acnn_inp_depth = inp_depth + 1 else: ccnn_inp_depth = inp_depth acnn_inp_depth = inp_depth ############################ # Controller CNN definition ############################ ccnn_filters = ctrl_cnn_filter_size ccnn_nlayers = len(ccnn_filters) ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth ccnn_pool = ctrl_cnn_pool ccnn_act = [tf.nn.relu] * ccnn_nlayers ccnn_use_bn = [use_bn] * ccnn_nlayers if pretrain_cnn: h5f = 
h5py.File(pretrain_cnn, 'r') ccnn_init_w = [{ 'w': h5f['cnn_w_{}'.format(ii)][:], 'b': h5f['cnn_b_{}'.format(ii)][:] } for ii in xrange(ccnn_nlayers)] ccnn_frozen = True else: ccnn_init_w = None ccnn_frozen = None ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act, ccnn_use_bn, phase_train=phase_train, wd=wd, scope='ctrl_cnn', model=model, init_weights=ccnn_init_w, frozen=ccnn_frozen) h_ccnn = [None] * timespan ############################ # Controller RNN definition ############################ ccnn_subsample = np.array(ccnn_pool).prod() crnn_h = inp_height / ccnn_subsample crnn_w = inp_width / ccnn_subsample crnn_dim = ctrl_rnn_hid_dim canvas_dim = inp_height * inp_width / (ccnn_subsample**2) glimpse_map_dim = crnn_h * crnn_w glimpse_feat_dim = ccnn_channels[-1] # crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1] crnn_state = [None] * (timespan + 1) crnn_glimpse_map = [None] * timespan crnn_g_i = [None] * timespan crnn_g_f = [None] * timespan crnn_g_o = [None] * timespan h_crnn = [None] * timespan crnn_cell = nn.lstm(glimpse_feat_dim, crnn_dim, wd=wd, scope='ctrl_lstm', model=model) ############################ # Glimpse MLP definition ############################ gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim] gmlp_act = [tf.nn.relu] * \ (num_glimpse_mlp_layers - 1) + [tf.nn.softmax] gmlp_dropout = None gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True, dropout_keep=gmlp_dropout, phase_train=phase_train, wd=wd, scope='glimpse_mlp', model=model) ############################ # Controller MLP definition ############################ cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \ (num_ctrl_mlp_layers - 1) + [9] cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None] cmlp_dropout = None # cmlp_dropout = [1.0 - mlp_dropout_ratio] * num_ctrl_mlp_layers cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True, dropout_keep=cmlp_dropout, phase_train=phase_train, wd=wd, scope='ctrl_mlp', model=model) ############################ # Attention CNN 
definition ############################ acnn_filters = attn_cnn_filter_size acnn_nlayers = len(acnn_filters) acnn_channels = [acnn_inp_depth] + attn_cnn_depth acnn_pool = attn_cnn_pool acnn_act = [tf.nn.relu] * acnn_nlayers acnn_use_bn = [use_bn] * acnn_nlayers if cnn_share_weights: ccnn_shared_weights = [] for ii in xrange(ccnn_nlayers): ccnn_shared_weights.append({ 'w': model['ctrl_cnn_w_{}'.format(ii)], 'b': model['ctrl_cnn_b_{}'.format(ii)] }) else: ccnn_shared_weights = None acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act, acnn_use_bn, phase_train=phase_train, wd=wd, scope='attn_cnn', model=model, shared_weights=ccnn_shared_weights) x_patch = [None] * timespan h_acnn = [None] * timespan h_acnn_last = [None] * timespan ############################ # Attention RNN definition ############################ acnn_subsample = np.array(acnn_pool).prod() arnn_h = filter_height / acnn_subsample arnn_w = filter_width / acnn_subsample if use_attn_rnn: arnn_dim = attn_rnn_hid_dim arnn_inp_dim = arnn_h * arnn_w * acnn_channels[-1] arnn_state = [None] * (timespan + 1) arnn_g_i = [None] * timespan arnn_g_f = [None] * timespan arnn_g_o = [None] * timespan arnn_state[-1] = tf.zeros(tf.pack([num_ex, arnn_dim * 2])) arnn_cell = nn.lstm(arnn_inp_dim, arnn_dim, wd=wd, scope='attn_lstm') amlp_inp_dim = arnn_dim else: amlp_inp_dim = arnn_h * arnn_w * acnn_channels[-1] ############################ # Attention MLP definition ############################ core_depth = attn_mlp_depth core_dim = arnn_h * arnn_w * core_depth amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers amlp_act = [tf.nn.relu] * num_attn_mlp_layers amlp_dropout = None # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout, phase_train=phase_train, wd=wd, scope='attn_mlp', model=model) # DCNN [B, RH, RW, MD] => [B, A, A, 1] adcnn_filters = attn_dcnn_filter_size adcnn_nlayers = len(adcnn_filters) adcnn_unpool = attn_dcnn_pool 
adcnn_act = [tf.nn.relu] * adcnn_nlayers adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth adcnn_use_bn = [use_bn] * adcnn_nlayers adcnn_skip_ch = [0] + acnn_channels[::-1][1:] adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool, adcnn_act, use_bn=adcnn_use_bn, skip_ch=adcnn_skip_ch, phase_train=phase_train, wd=wd, model=model, scope='attn_dcnn') h_adcnn = [None] * timespan ########################## # Score MLP definition ########################## smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid], wd=wd, scope='score_mlp', model=model) s_out = [None] * timespan ########################## # Attention box ########################## attn_ctr_norm = [None] * timespan attn_lg_size = [None] * timespan attn_ctr = [None] * timespan attn_size = [None] * timespan attn_lg_var = [None] * timespan attn_lg_gamma = [None] * timespan attn_gamma = [None] * timespan attn_box_lg_gamma = [None] * timespan attn_top_left = [None] * timespan attn_bot_right = [None] * timespan attn_box = [None] * timespan iou_soft_box = [None] * timespan const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1])) attn_box_beta = tf.constant([-5.0]) attn_box_gamma = [None] * timespan ############################# # Groundtruth attention box ############################# # [B, T, 2] attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \ attn_top_left_gt, attn_bot_right_gt = \ base.get_gt_attn(y_gt, padding_ratio=attn_box_padding_ratio, center_shift_ratio=0.0) attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \ attn_box_gt_noise, \ attn_top_left_gt_noise, attn_bot_right_gt_noise = \ base.get_gt_attn(y_gt, padding_ratio=tf.random_uniform( tf.pack([num_ex, timespan, 1]), attn_box_padding_ratio - gt_box_pad_noise, attn_box_padding_ratio + gt_box_pad_noise), center_shift_ratio=tf.random_uniform( tf.pack([num_ex, timespan, 2]), -gt_box_ctr_noise, gt_box_ctr_noise)) ########################## # Groundtruth mix ########################## grd_match_cum = tf.zeros(tf.pack([num_ex, 
timespan])) # Add a bias on every entry so there is no duplicate match # [1, N] iou_bias_eps = 1e-7 iou_bias = tf.expand_dims( tf.to_float(tf.reverse(tf.range(timespan), [True])) * iou_bias_eps, 0) # Scale mix ratio on different timesteps. gt_knob_time_scale = tf.reshape( 1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0 * float(knob_use_timescale)), [1, timespan, 1]) # Mix in groundtruth box. global_step_box = tf.maximum(0.0, global_step - knob_box_offset) gt_knob_prob_box = tf.train.exponential_decay(knob_base, global_step_box, steps_per_knob_decay, knob_decay, staircase=False) gt_knob_prob_box = tf.minimum(1.0, gt_knob_prob_box * gt_knob_time_scale) gt_knob_box = tf.to_float( tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box) model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0] # Mix in groundtruth segmentation. global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset) gt_knob_prob_segm = tf.train.exponential_decay(knob_base, global_step_segm, steps_per_knob_decay, knob_decay, staircase=False) gt_knob_prob_segm = tf.minimum(1.0, gt_knob_prob_segm * gt_knob_time_scale) gt_knob_segm = tf.to_float( tf.random_uniform(tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm) model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0] ########################## # Segmentation output ########################## y_out = [None] * timespan y_out_lg_gamma = [None] * timespan y_out_beta = tf.constant([-5.0]) ########################## # Computation graph ########################## if not use_canvas: h_ccnn = ccnn(x) for tt in xrange(timespan): # Controller CNN [B, H, W, D] => [B, RH1, RW1, RD1] if use_canvas: ccnn_inp = tf.concat(3, [x, canvas]) acnn_inp = ccnn_inp h_ccnn[tt] = ccnn(ccnn_inp) _h_ccnn = h_ccnn[tt] else: ccnn_inp = x acnn_inp = x _h_ccnn = h_ccnn h_ccnn_last = _h_ccnn[-1] # crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim]) crnn_inp = tf.reshape(h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim]) 
crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1) crnn_g_i[tt] = [None] * num_ctrl_rnn_iter crnn_g_f[tt] = [None] * num_ctrl_rnn_iter crnn_g_o[tt] = [None] * num_ctrl_rnn_iter h_crnn[tt] = [None] * num_ctrl_rnn_iter crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2])) # if tt == 0: # crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2])) # else: # crnn_state[tt][-1] = crnn_state[tt - 1][num_ctrl_rnn_iter - 1] crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter crnn_glimpse_map[tt][0] = tf.ones( tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim for tt2 in xrange(num_ctrl_rnn_iter): crnn_glimpse = tf.reduce_sum( crnn_inp * crnn_glimpse_map[tt][tt2], [1]) crnn_state[tt][tt2], crnn_g_i[tt][tt2], crnn_g_f[tt][tt2], \ crnn_g_o[tt][tt2] = \ crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1]) h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2], [0, crnn_dim], [-1, crnn_dim]) h_gmlp = gmlp(h_crnn[tt][tt2]) if tt2 < num_ctrl_rnn_iter - 1: crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims( h_gmlp[-1], 2) ctrl_out = cmlp(h_crnn[tt][-1])[-1] attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2]) attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2]) # Restrict to (-1, 1), (-inf, 0) if squash_ctrl_params: attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt]) attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt]) attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn( attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width) attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2])) attn_lg_gamma[tt] = tf.slice(ctrl_out, [0, 6], [-1, 1]) attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1]) y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1]) attn_gamma[tt] = tf.reshape(tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1]) attn_box_gamma[tt] = tf.reshape(tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1]) y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1]) # Initial filters (predicted) filter_y = get_gaussian_filter(attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 
0], inp_height, filter_height) filter_x = get_gaussian_filter(attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1], inp_width, filter_width) filter_y_inv = tf.transpose(filter_y, [0, 2, 1]) filter_x_inv = tf.transpose(filter_x, [0, 2, 1]) # Attention box if use_iou_box: _idx_map = get_idx_map(tf.pack([num_ex, inp_height, inp_width])) attn_top_left[tt], attn_bot_right[tt] = get_box_coord( attn_ctr[tt], attn_size[tt]) attn_box[tt] = get_filled_box_idx(_idx_map, attn_top_left[tt], attn_bot_right[tt]) attn_box[tt] = tf.reshape(attn_box[tt], [-1, 1, inp_height, inp_width]) else: attn_box[tt] = extract_patch(const_ones * attn_box_gamma[tt], filter_y_inv, filter_x_inv, 1) attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta) attn_box[tt] = tf.reshape(attn_box[tt], [-1, 1, inp_height, inp_width]) # Here is the knob kick in GT bbox. if use_knob: # IOU [B, 1, T] # [B, 1, H, W] * [B, T, H, W] = [B, T] if use_iou_box: _top_left = tf.expand_dims(attn_top_left[tt], 1) _bot_right = tf.expand_dims(attn_bot_right[tt], 1) iou_soft_box[tt] = f_iou_box(_top_left, _bot_right, attn_top_left_gt, attn_bot_right_gt) iou_soft_box[tt] += iou_bias else: iou_soft_box[tt] = f_inter(attn_box[tt], attn_box_gt) / \ f_union(attn_box[tt], attn_box_gt, eps=1e-5) grd_match = f_greedy_match(iou_soft_box[tt], grd_match_cum) if gt_selector == 'greedy_match': # Add in the cumulative matching to not double count. 
grd_match_cum += grd_match # [B, T, 1] grd_match = tf.expand_dims(grd_match, 2) attn_ctr_gt_match = tf.reduce_sum( grd_match * attn_ctr_gt_noise, 1) attn_size_gt_match = tf.reduce_sum( grd_match * attn_size_gt_noise, 1) _gt_knob_box = gt_knob_box attn_ctr[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \ attn_ctr_gt_match + \ (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \ attn_ctr[tt] attn_size[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \ attn_size_gt_match + \ (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \ attn_size[tt] attn_top_left[tt], attn_bot_right[tt] = get_box_coord( attn_ctr[tt], attn_size[tt]) filter_y = get_gaussian_filter(attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 0], inp_height, filter_height) filter_x = get_gaussian_filter(attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1], inp_width, filter_width) filter_y_inv = tf.transpose(filter_y, [0, 2, 1]) filter_x_inv = tf.transpose(filter_x, [0, 2, 1]) # Attended patch [B, A, A, D] x_patch[tt] = attn_gamma[tt] * extract_patch( acnn_inp, filter_y, filter_x, acnn_inp_depth) # CNN [B, A, A, D] => [B, RH2, RW2, RD2] h_acnn[tt] = acnn(x_patch[tt]) h_acnn_last[tt] = h_acnn[tt][-1] if use_attn_rnn: # RNN [B, T, R2] arnn_inp = tf.reshape(h_acnn_last[tt], [-1, arnn_inp_dim]) arnn_state[tt], arnn_g_i[tt], arnn_g_f[tt], arnn_g_o[tt] = \ arnn_cell(arnn_inp, arnn_state[tt - 1]) # Scoring network s_out[tt] = smlp(h_crnn[tt][-1])[-1] # Dense segmentation network [B, R] => [B, M] if use_attn_rnn: h_arnn = tf.slice(arnn_state[tt], [0, arnn_dim], [-1, arnn_dim]) amlp_inp = h_arnn else: amlp_inp = h_acnn_last[tt] amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim]) h_core = amlp(amlp_inp)[-1] h_core = tf.reshape(h_core, [-1, arnn_h, arnn_w, attn_mlp_depth]) # DCNN skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]] h_adcnn[tt] = adcnn(h_core, skip=skip) # Output y_out[tt] = extract_patch(h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1) y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * 
y_out[tt] + y_out_beta y_out[tt] = tf.sigmoid(y_out[tt]) y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width]) # Here is the knob kick in GT segmentations at this timestep. # [B, N, 1, 1] if use_canvas: if use_knob: _gt_knob_segm = tf.expand_dims( tf.expand_dims(gt_knob_segm[:, tt, 0:1], 2), 3) # [B, N, 1, 1] grd_match = tf.expand_dims(grd_match, 3) _y_out = tf.expand_dims(tf.reduce_sum(grd_match * y_gt, 1), 3) # Add independent uniform noise to groundtruth. _noise = tf.random_uniform( tf.pack([num_ex, inp_height, inp_width, 1]), 0, gt_segm_noise) _y_out = _y_out - _y_out * _noise _y_out = phase_train_f * _gt_knob_segm * _y_out + \ (1 - phase_train_f * _gt_knob_segm) * \ tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1]) else: _y_out = tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1]) canvas += tf.stop_gradient(_y_out) ######################### # Model outputs ######################### s_out = tf.concat(1, s_out) model['s_out'] = s_out y_out = tf.concat(1, y_out) model['y_out'] = y_out attn_box = tf.concat(1, attn_box) model['attn_box'] = attn_box x_patch = tf.concat( 1, [tf.expand_dims(x_patch[tt], 1) for tt in xrange(timespan)]) model['x_patch'] = x_patch ######################### # Loss function ######################### y_gt_shape = tf.shape(y_gt) num_ex_f = tf.to_float(y_gt_shape[0]) max_num_obj = tf.to_float(y_gt_shape[1]) ############################ # Box loss ############################ if use_knob: iou_soft_box = tf.concat(1, [ tf.expand_dims(iou_soft_box[tt], 1) for tt in xrange(timespan) ]) else: iou_soft_box = f_iou(attn_box, attn_box_gt, timespan, pairwise=True) model['iou_soft_box'] = iou_soft_box model['attn_box_gt'] = attn_box_gt match_box = f_segm_match(iou_soft_box, s_gt) model['match_box'] = match_box match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2]) match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1]) match_count_box = tf.maximum(1.0, match_count_box) iou_soft_box_mask = 
tf.reduce_sum(iou_soft_box * match_box, [1]) iou_soft_box = tf.reduce_sum( tf.reduce_sum(iou_soft_box_mask, [1]) / match_count_box) / num_ex_f gt_wt_box = f_coverage_weight(attn_box_gt) wt_iou_soft_box = tf.reduce_sum( tf.reduce_sum(iou_soft_box_mask * gt_wt_box, [1]) / match_count_box) / num_ex_f if box_loss_fn == 'iou': box_loss = -iou_soft_box elif box_loss_fn == 'wt_iou': box_loss = -wt_iou_soft_box elif box_loss_fn == 'wt_cov': box_loss = -f_weighted_coverage(iou_soft_box, attn_box_gt) elif box_loss_fn == 'bce': box_loss = f_match_bce(attn_box, attn_box_gt, match_box, timespan) else: raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn)) model['box_loss'] = box_loss box_loss_coeff = tf.constant(1.0) tf.add_to_collection('losses', box_loss_coeff * box_loss) ############################## # Segmentation loss ############################## # IoU (soft) iou_soft = f_iou(y_out, y_gt, timespan, pairwise=True) match = f_segm_match(iou_soft, s_gt) model['match'] = match match_sum = tf.reduce_sum(match, reduction_indices=[2]) match_count = tf.reduce_sum(match_sum, reduction_indices=[1]) match_count = tf.maximum(1.0, match_count) # Weighted coverage (soft) wt_cov_soft = f_weighted_coverage(iou_soft, y_gt) model['wt_cov_soft'] = wt_cov_soft unwt_cov_soft = f_unweighted_coverage(iou_soft, match_count) model['unwt_cov_soft'] = unwt_cov_soft # IOU (soft) iou_soft_mask = tf.reduce_sum(iou_soft * match, [1]) iou_soft = tf.reduce_sum( tf.reduce_sum(iou_soft_mask, [1]) / match_count) / num_ex_f model['iou_soft'] = iou_soft gt_wt = f_coverage_weight(y_gt) wt_iou_soft = tf.reduce_sum( tf.reduce_sum(iou_soft_mask * gt_wt, [1]) / match_count) / num_ex_f model['wt_iou_soft'] = wt_iou_soft if segm_loss_fn == 'iou': segm_loss = -iou_soft elif segm_loss_fn == 'wt_iou': segm_loss = -wt_iou_soft elif segm_loss_fn == 'wt_cov': segm_loss = -wt_cov_soft elif segm_loss_fn == 'bce': segm_loss = f_match_bce(y_out, y_gt, match, timespan) else: raise Exception('Unknown segm_loss_fn: 
{}'.format(segm_loss_fn)) model['segm_loss'] = segm_loss segm_loss_coeff = tf.constant(1.0) tf.add_to_collection('losses', segm_loss_coeff * segm_loss) #################### # Score loss #################### conf_loss = f_conf_loss(s_out, match, timespan, use_cum_min=True) model['conf_loss'] = conf_loss tf.add_to_collection('losses', loss_mix_ratio * conf_loss) #################### # Total loss #################### total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss') model['loss'] = total_loss #################### # Optimizer #################### learn_rate = tf.train.exponential_decay(base_learn_rate, global_step, steps_per_learn_rate_decay, learn_rate_decay, staircase=True) model['learn_rate'] = learn_rate eps = 1e-7 train_step = GradientClipOptimizer(tf.train.AdamOptimizer(learn_rate, epsilon=eps), clip=clip_gradient).minimize( total_loss, global_step=global_step) model['train_step'] = train_step #################### # Statistics #################### # [B, M, N] * [B, M, N] => [B] * [B] => [1] y_out_hard = tf.to_float(y_out > 0.5) iou_hard = f_iou(y_out_hard, y_gt, timespan, pairwise=True) wt_cov_hard = f_weighted_coverage(iou_hard, y_gt) model['wt_cov_hard'] = wt_cov_hard unwt_cov_hard = f_unweighted_coverage(iou_hard, match_count) model['unwt_cov_hard'] = unwt_cov_hard # [B, T] iou_hard_mask = tf.reduce_sum(iou_hard * match, [1]) iou_hard = tf.reduce_sum( tf.reduce_sum(iou_hard_mask, [1]) / match_count) / num_ex_f model['iou_hard'] = iou_hard wt_iou_hard = tf.reduce_sum( tf.reduce_sum(iou_hard_mask * gt_wt, [1]) / match_count) / num_ex_f model['wt_iou_hard'] = wt_iou_hard dice = f_dice(y_out_hard, y_gt, timespan, pairwise=True) dice = tf.reduce_sum( tf.reduce_sum(dice * match, [1, 2]) / match_count) / num_ex_f model['dice'] = dice model['count_acc'] = f_count_acc(s_out, s_gt) model['dic'] = f_dic(s_out, s_gt, abs=False) model['dic_abs'] = f_dic(s_out, s_gt, abs=True) ################################ # Controller output statistics 
################################ attn_top_left = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left]) attn_bot_right = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right]) attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr]) attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size]) attn_lg_gamma = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_gamma]) attn_box_lg_gamma = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in attn_box_lg_gamma]) y_out_lg_gamma = tf.concat( 1, [tf.expand_dims(tmp, 1) for tmp in y_out_lg_gamma]) attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / num_ex_f / timespan attn_box_lg_gamma_mean = tf.reduce_sum( attn_box_lg_gamma) / num_ex_f / timespan y_out_lg_gamma_mean = tf.reduce_sum( y_out_lg_gamma) / num_ex_f / timespan model['attn_ctr'] = attn_ctr model['attn_size'] = attn_size model['attn_top_left'] = attn_top_left model['attn_bot_right'] = attn_bot_right model['attn_lg_gamma_mean'] = attn_lg_gamma_mean model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean return model
def get_model(opt, device='/cpu:0'):
    """Build the box model graph.

    A controller CNN reads the input image concatenated with a one-channel
    canvas, a controller LSTM (either fed the flattened feature map, or an
    iteratively re-weighted "glimpse" of it) summarises it, and a controller
    MLP emits the attention box parameters for each of ``timespan`` steps.
    After every step the canvas is updated with a noised groundtruth
    segmentation (gradients are stopped through the canvas).

    Args:
        opt: Dictionary of hyper-parameters.  The keys read are exactly the
            ones looked up at the top of this function (``timespan``,
            ``inp_height``, ..., ``rnd_colour``); a missing key raises
            ``KeyError``.
        device: Device string handed to ``base.get_device_fn``.

    Returns:
        The ``model`` dictionary.  This function stores the placeholders
        (``x``, ``y_gt``, ``s_gt``, ``phase_train``), the transformed inputs
        (``x_trans``, ``y_gt_trans``), the outputs (``s_out``, ``attn_box``,
        ``attn_*`` box parameters and their groundtruth counterparts),
        ``match_box``, ``box_loss``, ``box_loss_coeff``, ``conf_loss``,
        ``loss``, ``learn_rate``, ``train_step`` and, for the ``'attn'``
        input structure only, ``ctrl_rnn_glimpse_map``.  The dictionary is
        also passed as ``model=`` to several ``nn.*`` constructors, which
        presumably register further entries in it.

    Raises:
        Exception: If ``ctrl_rnn_inp_struct`` is neither ``'dense'`` nor
            ``'attn'``, or if ``box_loss_fn`` is not a recognised name.
    """
    model = {}
    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    filter_height = opt['filter_height']
    filter_width = opt['filter_width']
    ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size']
    ctrl_cnn_depth = opt['ctrl_cnn_depth']
    ctrl_cnn_pool = opt['ctrl_cnn_pool']
    ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim']
    num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers']
    ctrl_mlp_dim = opt['ctrl_mlp_dim']
    attn_box_padding_ratio = opt['attn_box_padding_ratio']
    wd = opt['weight_decay']
    use_bn = opt['use_bn']
    box_loss_fn = opt['box_loss_fn']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    pretrain_cnn = opt['pretrain_cnn']
    squash_ctrl_params = opt['squash_ctrl_params']
    clip_gradient = opt['clip_gradient']
    fixed_order = opt['fixed_order']
    ctrl_rnn_inp_struct = opt['ctrl_rnn_inp_struct']  # dense or attn
    num_ctrl_rnn_iter = opt['num_ctrl_rnn_iter']
    num_glimpse_mlp_layers = opt['num_glimpse_mlp_layers']
    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    ############################
    # Input definition
    ############################
    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float',
                           [None, inp_height, inp_width, inp_depth],
                           name='x')
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Groundtruth segmentation, [B, T, H, W]
        y_gt = tf.placeholder('float',
                              [None, timespan, inp_height, inp_width],
                              name='y_gt')

        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan], name='s_gt')

        # Whether in training stage.
        phase_train = tf.placeholder('bool', name='phase_train')
        # Kept although unused below, so the set of graph ops is unchanged.
        phase_train_f = tf.to_float(phase_train)
        model['x'] = x
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt
        model['phase_train'] = phase_train

        # Global step
        global_step = tf.Variable(0.0, name='global_step')

        ###############################
        # Random input transformation
        ###############################
        x, y_gt = img.random_transformation(x, y_gt, padding, phase_train,
                                            rnd_hflip=rnd_hflip,
                                            rnd_vflip=rnd_vflip,
                                            rnd_transpose=rnd_transpose,
                                            rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        ############################
        # Canvas: external memory
        ############################
        canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1]))
        # The controller CNN sees the image plus the one-channel canvas.
        ccnn_inp_depth = inp_depth + 1

        ############################
        # Controller CNN definition
        ############################
        ccnn_filters = ctrl_cnn_filter_size
        ccnn_nlayers = len(ccnn_filters)
        ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth
        ccnn_pool = ctrl_cnn_pool
        ccnn_act = [tf.nn.relu] * ccnn_nlayers
        ccnn_use_bn = [use_bn] * ccnn_nlayers

        if pretrain_cnn:
            log.info('Loading pretrained weights from {}'.format(pretrain_cnn))
            # Every dataset is copied out eagerly with [:], so the file can
            # be closed as soon as loading is done.
            with h5py.File(pretrain_cnn, 'r') as h5f:
                acnn_nlayers = 0
                # Assuming acnn_nlayers is smaller than ccnn_nlayers.
                for ii in xrange(ccnn_nlayers):
                    if 'attn_cnn_w_{}'.format(ii) in h5f:
                        log.info('Loading attn_cnn_w_{}'.format(ii))
                        log.info('Loading attn_cnn_b_{}'.format(ii))
                        acnn_nlayers += 1
                ccnn_init_w = [{
                    'w': h5f['attn_cnn_w_{}'.format(ii)][:],
                    'b': h5f['attn_cnn_b_{}'.format(ii)][:]
                } for ii in xrange(acnn_nlayers)]
                # Per-timestep 'beta' / 'gamma' entries stored next to the
                # conv weights.
                for ii in xrange(acnn_nlayers):
                    for tt in xrange(timespan):
                        for w in ['beta', 'gamma']:
                            ccnn_init_w[ii]['{}_{}'.format(w, tt)] = \
                                h5f['attn_cnn_{}_{}_{}'.format(ii, tt, w)][:]
            # Layers that received pretrained weights are frozen; the
            # remaining layers get no initial weights and stay trainable.
            ccnn_frozen = [True] * acnn_nlayers
            for ii in xrange(acnn_nlayers, ccnn_nlayers):
                ccnn_init_w.append(None)
                ccnn_frozen.append(False)
        else:
            ccnn_init_w = None
            ccnn_frozen = None

        ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act,
                      ccnn_use_bn, phase_train=phase_train, wd=wd,
                      scope='ctrl_cnn', model=model,
                      init_weights=ccnn_init_w, frozen=ccnn_frozen)
        h_ccnn = [None] * timespan

        ############################
        # Controller RNN definition
        ############################
        ccnn_subsample = np.array(ccnn_pool).prod()
        crnn_h = inp_height / ccnn_subsample
        crnn_w = inp_width / ccnn_subsample
        crnn_dim = ctrl_rnn_hid_dim
        glimpse_map_dim = crnn_h * crnn_w
        glimpse_feat_dim = ccnn_channels[-1]
        if ctrl_rnn_inp_struct == 'dense':
            crnn_inp_dim = crnn_h * crnn_w * ccnn_channels[-1]
        elif ctrl_rnn_inp_struct == 'attn':
            crnn_inp_dim = glimpse_feat_dim
        else:
            # Fail here with a clear message instead of a NameError on
            # crnn_inp_dim a few lines further down.
            raise Exception('Unknown ctrl_rnn_inp_struct: {}'.format(
                ctrl_rnn_inp_struct))
        crnn_state = [None] * (timespan + 1)
        crnn_glimpse_map = [None] * timespan
        crnn_g_i = [None] * timespan
        crnn_g_f = [None] * timespan
        crnn_g_o = [None] * timespan
        h_crnn = [None] * timespan
        # Slot -1 holds the initial state, so crnn_state[tt - 1] works at
        # tt == 0.
        crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
        crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm',
                            model=model)

        ############################
        # Glimpse MLP definition
        ############################
        gmlp_dims = [crnn_dim] * num_glimpse_mlp_layers + [glimpse_map_dim]
        gmlp_act = [tf.nn.relu] * \
            (num_glimpse_mlp_layers - 1) + [tf.nn.softmax]
        gmlp_dropout = None
        gmlp = nn.mlp(gmlp_dims, gmlp_act, add_bias=True,
                      dropout_keep=gmlp_dropout, phase_train=phase_train,
                      wd=wd, scope='glimpse_mlp', model=model)

        ############################
        # Controller MLP definition
        ############################
        cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \
            (num_ctrl_mlp_layers - 1) + [9]
        cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None]
        cmlp_dropout = None
        cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True,
                      dropout_keep=cmlp_dropout, phase_train=phase_train,
                      wd=wd, scope='ctrl_mlp', model=model)

        ##########################
        # Score MLP definition
        ##########################
        smlp = nn.mlp([crnn_dim, 1], [tf.sigmoid], wd=wd, scope='score_mlp')
        s_out = [None] * timespan

        ##########################
        # Attention box
        ##########################
        attn_ctr_norm = [None] * timespan
        attn_lg_size = [None] * timespan
        attn_lg_var = [None] * timespan
        attn_ctr = [None] * timespan
        attn_size = [None] * timespan
        attn_top_left = [None] * timespan
        attn_bot_right = [None] * timespan
        attn_box = [None] * timespan
        attn_box_lg_gamma = [None] * timespan
        attn_box_gamma = [None] * timespan
        const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1]))
        attn_box_beta = tf.constant([-5.0])
        iou_soft_box = [None] * timespan

        #############################
        # Groundtruth attention box
        #############################
        attn_top_left_gt, attn_bot_right_gt, attn_box_gt = base.get_gt_box(
            y_gt, padding_ratio=attn_box_padding_ratio,
            center_shift_ratio=0.0)
        attn_ctr_gt, attn_size_gt = base.get_box_ctr_size(
            attn_top_left_gt, attn_bot_right_gt)
        attn_ctr_norm_gt = base.get_normalized_center(attn_ctr_gt,
                                                      inp_height, inp_width)
        attn_lg_size_gt = base.get_normalized_size(attn_size_gt,
                                                   inp_height, inp_width)

        ##########################
        # Groundtruth mix
        ##########################
        # NOTE(review): this accumulator is passed to f_greedy_match below
        # but is never updated in this function -- confirm that is intended.
        grd_match_cum = tf.zeros(tf.pack([num_ex, timespan]))

        ##########################
        # Computation graph
        ##########################
        for tt in xrange(timespan):
            # Controller CNN
            ccnn_inp = tf.concat(3, [x, canvas])
            h_ccnn[tt] = ccnn(ccnn_inp)
            _h_ccnn = h_ccnn[tt]
            h_ccnn_last = _h_ccnn[-1]

            # Controller RNN [B, R1]
            if ctrl_rnn_inp_struct == 'dense':
                crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim])
                crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[tt] = \
                    crnn_cell(crnn_inp, crnn_state[tt - 1])
                # The second half of the state vector is used as the hidden
                # output.
                h_crnn[tt] = tf.slice(crnn_state[tt], [0, crnn_dim],
                                      [-1, crnn_dim])
                ctrl_out = cmlp(h_crnn[tt])[-1]
            elif ctrl_rnn_inp_struct == 'attn':
                crnn_inp = tf.reshape(
                    h_ccnn_last, [-1, glimpse_map_dim, glimpse_feat_dim])
                crnn_state[tt] = [None] * (num_ctrl_rnn_iter + 1)
                crnn_g_i[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_f[tt] = [None] * num_ctrl_rnn_iter
                crnn_g_o[tt] = [None] * num_ctrl_rnn_iter
                h_crnn[tt] = [None] * num_ctrl_rnn_iter
                # The inner RNN restarts from a zero state at every outer
                # timestep.
                crnn_state[tt][-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2]))
                crnn_glimpse_map[tt] = [None] * num_ctrl_rnn_iter
                # The first glimpse weights every location equally.
                crnn_glimpse_map[tt][0] = tf.ones(
                    tf.pack([num_ex, glimpse_map_dim, 1])) / glimpse_map_dim

                # Inner glimpse RNN
                for tt2 in xrange(num_ctrl_rnn_iter):
                    crnn_glimpse = tf.reduce_sum(
                        crnn_inp * crnn_glimpse_map[tt][tt2], [1])
                    crnn_state[tt][tt2], crnn_g_i[tt][tt2], \
                        crnn_g_f[tt][tt2], crnn_g_o[tt][tt2] = \
                        crnn_cell(crnn_glimpse, crnn_state[tt][tt2 - 1])
                    h_crnn[tt][tt2] = tf.slice(crnn_state[tt][tt2],
                                               [0, crnn_dim], [-1, crnn_dim])
                    h_gmlp = gmlp(h_crnn[tt][tt2])
                    # The glimpse MLP output becomes the next iteration's
                    # glimpse map.
                    if tt2 < num_ctrl_rnn_iter - 1:
                        crnn_glimpse_map[tt][tt2 + 1] = tf.expand_dims(
                            h_gmlp[-1], 2)
                ctrl_out = cmlp(h_crnn[tt][-1])[-1]

            attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2])
            attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2])

            # Restrict to (-1, 1), (-inf, 0)
            if squash_ctrl_params:
                attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt])
                attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt])

            attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn(
                attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width)
            attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2]))
            attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1])
            attn_box_gamma[tt] = tf.reshape(tf.exp(attn_box_lg_gamma[tt]),
                                            [-1, 1, 1, 1])
            attn_top_left[tt], attn_bot_right[tt] = base.get_box_coord(
                attn_ctr[tt], attn_size[tt])

            # Initial filters (predicted)
            filter_y = base.get_gaussian_filter(attn_ctr[tt][:, 0],
                                                attn_size[tt][:, 0],
                                                attn_lg_var[tt][:, 0],
                                                inp_height, filter_height)
            filter_x = base.get_gaussian_filter(attn_ctr[tt][:, 1],
                                                attn_size[tt][:, 1],
                                                attn_lg_var[tt][:, 1],
                                                inp_width, filter_width)
            filter_y_inv = tf.transpose(filter_y, [0, 2, 1])
            filter_x_inv = tf.transpose(filter_x, [0, 2, 1])

            # Attention box
            attn_box[tt] = base.extract_patch(const_ones * attn_box_gamma[tt],
                                              filter_y_inv, filter_x_inv, 1)
            attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta)
            attn_box[tt] = tf.reshape(attn_box[tt],
                                      [-1, 1, inp_height, inp_width])

            if fixed_order:
                _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3)
            else:
                iou_soft_box[tt] = base.f_inter(
                    attn_box[tt], attn_box_gt) / \
                    base.f_union(attn_box[tt], attn_box_gt, eps=1e-5)
                grd_match = base.f_greedy_match(iou_soft_box[tt],
                                                grd_match_cum)
                grd_match = tf.expand_dims(tf.expand_dims(grd_match, 2), 3)
                _y_out = tf.expand_dims(
                    tf.reduce_sum(grd_match * y_gt, 1), 3)

            # Add independent uniform noise to groundtruth.
            _noise = tf.random_uniform(
                tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3)
            _y_out = _y_out - _y_out * _noise
            # No gradient flows through the canvas.
            canvas = tf.stop_gradient(tf.maximum(_y_out, canvas))

            # Scoring network
            if ctrl_rnn_inp_struct == 'dense':
                s_out[tt] = smlp(h_crnn[tt])[-1]
            elif ctrl_rnn_inp_struct == 'attn':
                s_out[tt] = smlp(h_crnn[tt][-1])[-1]

        #########################
        # Model outputs
        #########################
        s_out = tf.concat(1, s_out)
        model['s_out'] = s_out
        attn_box = tf.concat(1, attn_box)
        model['attn_box'] = attn_box
        attn_top_left = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left])
        attn_bot_right = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right])
        attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr])
        attn_size = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_size])
        model['attn_top_left'] = attn_top_left
        model['attn_bot_right'] = attn_bot_right
        model['attn_ctr'] = attn_ctr
        model['attn_size'] = attn_size
        model['attn_ctr_norm_gt'] = attn_ctr_norm_gt
        model['attn_lg_size_gt'] = attn_lg_size_gt
        model['attn_top_left_gt'] = attn_top_left_gt
        model['attn_bot_right_gt'] = attn_bot_right_gt
        model['attn_box_gt'] = attn_box_gt
        attn_ctr_norm = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr_norm])
        attn_lg_size = tf.concat(
            1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_size])
        model['attn_ctr_norm'] = attn_ctr_norm
        model['attn_lg_size'] = attn_lg_size
        attn_params = tf.concat(2, [attn_ctr_norm, attn_lg_size])
        attn_params_gt = tf.concat(2, [attn_ctr_norm_gt, attn_lg_size_gt])

        #########################
        # Loss function
        #########################
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])
        # Kept although unused below, so the set of graph ops is unchanged.
        max_num_obj = tf.to_float(y_gt_shape[1])

        ############################
        # Box loss
        ############################
        if fixed_order:
            # [B, T] for fixed order.
            iou_soft_box = base.f_iou(attn_box, attn_box_gt, pairwise=False)
        else:
            # [B, T, T] for matching.
            iou_soft_box = tf.concat(1, [
                tf.expand_dims(iou_soft_box[tt], 1)
                for tt in xrange(timespan)
            ])

        identity_match = base.get_identity_match(num_ex, timespan, s_gt)
        if fixed_order:
            match_box = identity_match
        else:
            match_box = base.f_segm_match(iou_soft_box, s_gt)

        model['match_box'] = match_box
        match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2])
        match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1])
        # Clamp to at least 1 so the division below cannot divide by zero.
        match_count_box = tf.maximum(1.0, match_count_box)

        # [B] if fixed order, [B, T] if matching.
        if fixed_order:
            iou_soft_box_mask = iou_soft_box
        else:
            iou_soft_box_mask = tf.reduce_sum(iou_soft_box * match_box, [1])
        iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1])
        iou_soft_box = tf.reduce_sum(
            iou_soft_box / match_count_box) / num_ex_f

        # One single if/elif chain: previously 'iou' started a second `if`,
        # so 'mse' and 'huber' computed their loss and then fell into the
        # trailing `else` and raised "Unknown box_loss_fn".
        if box_loss_fn == 'mse':
            box_loss = base.f_match_loss(attn_params, attn_params_gt,
                                         match_box, timespan,
                                         base.f_squared_err, model=model)
        elif box_loss_fn == 'huber':
            box_loss = base.f_match_loss(attn_params, attn_params_gt,
                                         match_box, timespan, base.f_huber)
        elif box_loss_fn == 'iou':
            box_loss = -iou_soft_box
        elif box_loss_fn == 'wt_iou':
            # NOTE(review): wt_iou_soft_box is not defined anywhere in this
            # function, so this branch raises NameError as written.
            box_loss = -wt_iou_soft_box
        elif box_loss_fn == 'wt_cov':
            # NOTE(review): box_map_gt is not defined in this function, and
            # iou_soft_box has already been reduced above -- confirm the
            # intended inputs.
            box_loss = -base.f_weighted_coverage(iou_soft_box, box_map_gt)
        elif box_loss_fn == 'bce':
            # NOTE(review): box_map / box_map_gt are not defined in this
            # function; presumably attn_box / attn_box_gt -- TODO confirm.
            box_loss = base.f_match_loss(box_map, box_map_gt, match_box,
                                         timespan, base.f_bce)
        else:
            raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn))
        model['box_loss'] = box_loss

        box_loss_coeff = tf.constant(1.0)
        model['box_loss_coeff'] = box_loss_coeff
        tf.add_to_collection('losses', box_loss_coeff * box_loss)

        ####################
        # Score loss
        ####################
        conf_loss = base.f_conf_loss(s_out, match_box, timespan,
                                     use_cum_min=True)
        model['conf_loss'] = conf_loss
        conf_loss_coeff = tf.constant(1.0)
        tf.add_to_collection('losses', conf_loss_coeff * conf_loss)

        ####################
        # Total loss
        ####################
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
        model['loss'] = total_loss

        ####################
        # Optimizer
        ####################
        learn_rate = tf.train.exponential_decay(base_learn_rate, global_step,
                                                steps_per_learn_rate_decay,
                                                learn_rate_decay,
                                                staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7
        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

        ####################
        # Glimpse
        ####################
        # T * T2 * [B, H' * W'] => [B, T, T2, H', W']
        if ctrl_rnn_inp_struct == 'attn':
            crnn_glimpse_map = tf.concat(1, [
                tf.expand_dims(
                    tf.concat(1, [
                        tf.expand_dims(crnn_glimpse_map[tt][tt2], 1)
                        for tt2 in xrange(num_ctrl_rnn_iter)
                    ]), 1)
                for tt in xrange(timespan)
            ])
            crnn_glimpse_map = tf.reshape(
                crnn_glimpse_map,
                [-1, timespan, num_ctrl_rnn_iter, crnn_h, crnn_w])
            model['ctrl_rnn_glimpse_map'] = crnn_glimpse_map

    return model
def get_model(opt, device='/cpu:0'):
    """Build the pairwise matching network graph.

    Two images go through the same CNN, their last-layer features are
    flattened and concatenated, and an MLP ending in a sigmoid produces one
    score per pair.  The score is trained against ``y_gt`` with ``f_bce``
    using Adam and a staircase exponentially decayed learning rate.

    Args:
        opt: Dictionary of hyper-parameters; the keys read are
            ``inp_height``, ``inp_width``, ``inp_depth``,
            ``cnn_filter_size``, ``cnn_depth``, ``cnn_pool``, ``mlp_dims``,
            ``mlp_dropout``, ``weight_decay``, ``base_learn_rate``,
            ``learn_rate_decay`` and ``steps_per_learn_rate_decay``.
        device: Device string handed to ``get_device_fn``.

    Returns:
        Dictionary with the graph nodes ``x1``, ``x2``, ``y_gt``,
        ``phase_train``, ``y_out``, ``loss``, ``acc``, ``learn_rate`` and
        ``train_step``.
    """
    height = opt['inp_height']
    width = opt['inp_width']
    depth = opt['inp_depth']
    conv_filters = opt['cnn_filter_size']
    conv_depth = opt['cnn_depth']
    conv_pool = opt['cnn_pool']
    fc_dims = opt['mlp_dims']
    drop_ratio = opt['mlp_dropout']
    weight_decay = opt['weight_decay']
    lr_base = opt['base_learn_rate']
    lr_decay = opt['learn_rate_decay']
    lr_decay_steps = opt['steps_per_learn_rate_decay']

    with tf.device(get_device_fn(device)):
        # ---- Inputs: the two images, the train/eval switch, the label ----
        x1 = tf.placeholder('float', [None, height, width, depth], name='x1')
        x2 = tf.placeholder('float', [None, height, width, depth], name='x2')
        phase_train = tf.placeholder('bool', name='phase_train')
        y_gt = tf.placeholder('float', [None], name='y_gt')
        global_step = tf.Variable(0.0)

        # ---- Shared feature CNN (batch norm on every layer) ----
        conv_channels = [depth] + conv_depth
        n_conv = len(conv_filters)
        cnn = nn.cnn(conv_filters, conv_channels, conv_pool,
                     [tf.nn.relu] * n_conv, [True] * n_conv,
                     phase_train=phase_train, wd=weight_decay, scope='cnn')
        # Total spatial down-sampling is the product of all pooling factors.
        pool_total = np.array(conv_pool).prod()
        feat_h = height / pool_total
        feat_w = width / pool_total
        feat_dim = feat_h * feat_w * conv_channels[-1]

        # ---- Matching MLP: ReLU hidden layers, sigmoid output ----
        n_fc = len(fc_dims)
        mlp = nn.mlp([2 * feat_dim] + fc_dims,
                     [tf.nn.relu] * (n_fc - 1) + [tf.sigmoid],
                     dropout_keep=[1 - drop_ratio] * n_fc,
                     phase_train=phase_train)

        # ---- Forward pass ----
        feat_a = tf.reshape(cnn(x1)[-1], [-1, feat_dim])
        feat_b = tf.reshape(cnn(x2)[-1], [-1, feat_dim])
        joined = tf.concat(1, [feat_a, feat_b])
        y_out = tf.reshape(mlp(joined)[-1], [-1])

        # ---- Loss: batch-averaged binary cross entropy ----
        batch_f = tf.to_float(tf.shape(y_gt)[0])
        mean_bce = tf.reduce_sum(f_bce(y_out, y_gt)) / batch_f
        tf.add_to_collection('losses', mean_bce)
        total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')

        # ---- Accuracy at a 0.5 decision threshold ----
        y_hard = tf.to_float(y_out > 0.5)
        n_correct = tf.reduce_sum(tf.to_float(tf.equal(y_hard, y_gt)))
        acc = n_correct / batch_f

        # ---- Optimizer ----
        learn_rate = tf.train.exponential_decay(
            lr_base, global_step, lr_decay_steps, lr_decay, staircase=True)
        adam = tf.train.AdamOptimizer(learn_rate, epsilon=1e-7)
        train_step = adam.minimize(total_loss, global_step=global_step)

    # ---- Nodes exposed to the caller ----
    model = {
        'x1': x1,
        'x2': x2,
        'y_gt': y_gt,
        'phase_train': phase_train,
        'y_out': y_out,
        'loss': total_loss,
        'acc': acc,
        'learn_rate': learn_rate,
        'train_step': train_step,
    }
    return model
def get_model(opt, device='/cpu:0'): """The original model""" model = {} timespan = opt['timespan'] inp_height = opt['inp_height'] inp_width = opt['inp_width'] inp_depth = opt['inp_depth'] padding = opt['padding'] filter_height = opt['filter_height'] filter_width = opt['filter_width'] ctrl_cnn_filter_size = opt['ctrl_cnn_filter_size'] ctrl_cnn_depth = opt['ctrl_cnn_depth'] ctrl_cnn_pool = opt['ctrl_cnn_pool'] ctrl_rnn_hid_dim = opt['ctrl_rnn_hid_dim'] num_ctrl_mlp_layers = opt['num_ctrl_mlp_layers'] ctrl_mlp_dim = opt['ctrl_mlp_dim'] attn_cnn_filter_size = opt['attn_cnn_filter_size'] attn_cnn_depth = opt['attn_cnn_depth'] attn_dcnn_filter_size = opt['attn_dcnn_filter_size'] attn_dcnn_depth = opt['attn_dcnn_depth'] attn_dcnn_pool = opt['attn_dcnn_pool'] attn_rnn_hid_dim = opt['attn_rnn_hid_dim'] mlp_dropout_ratio = opt['mlp_dropout'] num_attn_mlp_layers = opt['num_attn_mlp_layers'] attn_mlp_depth = opt['attn_mlp_depth'] attn_box_padding_ratio = opt['attn_box_padding_ratio'] wd = opt['weight_decay'] use_bn = opt['use_bn'] use_gt_attn = opt['use_gt_attn'] segm_loss_fn = opt['segm_loss_fn'] box_loss_fn = opt['box_loss_fn'] loss_mix_ratio = opt['loss_mix_ratio'] base_learn_rate = opt['base_learn_rate'] learn_rate_decay = opt['learn_rate_decay'] steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay'] use_attn_rnn = opt['use_attn_rnn'] use_knob = opt['use_knob'] knob_base = opt['knob_base'] knob_decay = opt['knob_decay'] steps_per_knob_decay = opt['steps_per_knob_decay'] use_canvas = opt['use_canvas'] knob_box_offset = opt['knob_box_offset'] knob_segm_offset = opt['knob_segm_offset'] knob_use_timescale = opt['knob_use_timescale'] gt_box_ctr_noise = opt['gt_box_ctr_noise'] gt_box_pad_noise = opt['gt_box_pad_noise'] gt_segm_noise = opt['gt_segm_noise'] squash_ctrl_params = opt['squash_ctrl_params'] use_iou_box = opt['use_iou_box'] fixed_order = opt['fixed_order'] rnd_hflip = opt['rnd_hflip'] rnd_vflip = opt['rnd_vflip'] rnd_transpose = opt['rnd_transpose'] rnd_colour 
= opt['rnd_colour'] ############################ # Input definition ############################ with tf.device(base.get_device_fn(device)): # Input image, [B, H, W, D] x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth]) x_shape = tf.shape(x) num_ex = x_shape[0] # Groundtruth segmentation, [B, T, H, W] y_gt = tf.placeholder('float', [None, timespan, inp_height, inp_width]) # Groundtruth confidence score, [B, T] s_gt = tf.placeholder('float', [None, timespan]) # Whether in training stage. phase_train = tf.placeholder('bool') phase_train_f = tf.to_float(phase_train) model['x'] = x model['y_gt'] = y_gt model['s_gt'] = s_gt model['phase_train'] = phase_train # Global step global_step = tf.Variable(0.0) # Random image transformation x, y_gt = img.random_transformation( x, y_gt, padding, phase_train, rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip, rnd_transpose=rnd_transpose, rnd_colour=rnd_colour) model['x_trans'] = x model['y_gt_trans'] = y_gt ############################ # Canvas: external memory ############################ if use_canvas: canvas = tf.zeros(tf.pack([num_ex, inp_height, inp_width, 1])) ccnn_inp_depth = inp_depth + 1 # ccnn_inp_depth = inp_depth acnn_inp_depth = inp_depth + 1 else: ccnn_inp_depth = inp_depth acnn_inp_depth = inp_depth ############################ # Controller CNN definition ############################ ccnn_filters = ctrl_cnn_filter_size ccnn_nlayers = len(ccnn_filters) ccnn_channels = [ccnn_inp_depth] + ctrl_cnn_depth ccnn_pool = ctrl_cnn_pool ccnn_act = [tf.nn.relu] * ccnn_nlayers ccnn_use_bn = [use_bn] * ccnn_nlayers ccnn = nn.cnn(ccnn_filters, ccnn_channels, ccnn_pool, ccnn_act, ccnn_use_bn, phase_train=phase_train, wd=wd, scope='ctrl_cnn', model=model) h_ccnn = [None] * timespan ############################ # Controller RNN definition ############################ ccnn_subsample = np.array(ccnn_pool).prod() crnn_h = inp_height / ccnn_subsample crnn_w = inp_width / ccnn_subsample crnn_dim = ctrl_rnn_hid_dim crnn_inp_dim = 
crnn_h * crnn_w * ccnn_channels[-1] crnn_state = [None] * (timespan + 1) crnn_g_i = [None] * timespan crnn_g_f = [None] * timespan crnn_g_o = [None] * timespan h_crnn = [None] * timespan crnn_state[-1] = tf.zeros(tf.pack([num_ex, crnn_dim * 2])) crnn_cell = nn.lstm(crnn_inp_dim, crnn_dim, wd=wd, scope='ctrl_lstm', model=model) ############################ # Controller MLP definition ############################ cmlp_dims = [crnn_dim] + [ctrl_mlp_dim] * \ (num_ctrl_mlp_layers - 1) + [9] cmlp_act = [tf.nn.relu] * (num_ctrl_mlp_layers - 1) + [None] cmlp_dropout = None cmlp = nn.mlp(cmlp_dims, cmlp_act, add_bias=True, dropout_keep=cmlp_dropout, phase_train=phase_train, wd=wd, scope='ctrl_mlp', model=model) ########################## # Attention CNN definition ########################## acnn_filters = attn_cnn_filter_size acnn_nlayers = len(acnn_filters) acnn_channels = [acnn_inp_depth] + attn_cnn_depth acnn_pool = [2] * acnn_nlayers acnn_act = [tf.nn.relu] * acnn_nlayers acnn_use_bn = [use_bn] * acnn_nlayers acnn = nn.cnn(acnn_filters, acnn_channels, acnn_pool, acnn_act, acnn_use_bn, phase_train=phase_train, wd=wd, scope='attn_cnn', model=model) x_patch = [None] * timespan h_acnn = [None] * timespan h_acnn_last = [None] * timespan ########################## # Attention RNN definition ########################## acnn_subsample = np.array(acnn_pool).prod() arnn_h = filter_height / acnn_subsample arnn_w = filter_width / acnn_subsample if use_attn_rnn: arnn_dim = attn_rnn_hid_dim arnn_inp_dim = arnn_h * arnn_w * acnn_channels[-1] arnn_state = [None] * (timespan + 1) arnn_g_i = [None] * timespan arnn_g_f = [None] * timespan arnn_g_o = [None] * timespan arnn_state[-1] = tf.zeros(tf.pack([num_ex, arnn_dim * 2])) arnn_cell = nn.lstm(arnn_inp_dim, arnn_dim, wd=wd, scope='attn_lstm') amlp_inp_dim = arnn_dim else: amlp_inp_dim = arnn_h * arnn_w * acnn_channels[-1] ########################## # Attention MLP definition ########################## core_depth = attn_mlp_depth core_dim 
= arnn_h * arnn_w * core_depth amlp_dims = [amlp_inp_dim] + [core_dim] * num_attn_mlp_layers amlp_act = [tf.nn.relu] * num_attn_mlp_layers amlp_dropout = None # amlp_dropout = [1.0 - mlp_dropout_ratio] * num_attn_mlp_layers amlp = nn.mlp(amlp_dims, amlp_act, dropout_keep=amlp_dropout, phase_train=phase_train, wd=wd, scope='attn_mlp', model=model) ########################## # Score MLP definition ########################## smlp = nn.mlp([crnn_dim + core_dim, 1], [tf.sigmoid], wd=wd, scope='score_mlp', model=model) s_out = [None] * timespan ############################# # Attention DCNN definition ############################# adcnn_filters = attn_dcnn_filter_size adcnn_nlayers = len(adcnn_filters) adcnn_unpool = [2] * (adcnn_nlayers - 1) + [1] adcnn_act = [tf.nn.relu] * adcnn_nlayers adcnn_channels = [attn_mlp_depth] + attn_dcnn_depth adcnn_use_bn = [use_bn] * dcnn_nlayers adcnn_skip_ch = [0] + acnn_channels[::-1][1:] + [ccnn_inp_depth] adcnn = nn.dcnn(adcnn_filters, adcnn_channels, adcnn_unpool, adcnn_act, use_bn=adcnn_use_bn, skip_ch=adcnn_skip_ch, phase_train=phase_train, wd=wd, model=model) h_adcnn = [None] * timespan ########################## # Attention box ########################## attn_box = [None] * timespan iou_soft_box = [None] * timespan const_ones = tf.ones(tf.pack([num_ex, filter_height, filter_width, 1])) attn_box_beta = -5.0 ############################# # Groundtruth attention box ############################# # [B, T, 2] attn_ctr_gt, attn_size_gt, attn_lg_var_gt, attn_box_gt, \ attn_top_left_gt, attn_bot_right_gt = \ base.get_gt_attn(y_gt, padding_ratio=attn_box_padding_ratio, center_shift_ratio=0.0) attn_ctr_gt_noise, attn_size_gt_noise, attn_lg_var_gt_noise, \ attn_box_gt_noise, \ attn_top_left_gt_noise, attn_bot_right_gt_noise = \ base.get_gt_attn(y_gt, padding_ratio=tf.random_uniform( tf.pack([num_ex, timespan, 1]), attn_box_padding_ratio - gt_box_pad_noise, attn_box_padding_ratio + gt_box_pad_noise), center_shift_ratio=tf.random_uniform( 
tf.pack([num_ex, timespan, 2]), -gt_box_ctr_noise, gt_box_ctr_noise)) attn_delta_gt = _get_delta_from_size( attn_size_gt, filter_height, filter_width) attn_delta_gt_noise = _get_delta_from_size( attn_size_gt_noise, filter_height, filter_width) attn_lg_gamma_gt = tf.zeros(tf.pack([num_ex, timespan, 1])) attn_box_lg_gamma_gt = tf.zeros(tf.pack([num_ex, timespan, 1])) y_out_lg_gamma_gt = tf.zeros(tf.pack([num_ex, timespan, 1])) gtbox_top_left = [None] * timespan gtbox_bot_right = [None] * timespan attn_ctr = [None] * timespan attn_ctr_norm = [None] * timespan attn_delta = [None] * timespan attn_lg_size = [None] * timespan attn_size = [None] * timespan attn_lg_var = [None] * timespan attn_lg_gamma = [None] * timespan attn_gamma = [None] * timespan attn_box_lg_gamma = [None] * timespan attn_box_gamma = [None] * timespan attn_top_left = [None] * timespan attn_bot_right = [None] * timespan ########################## # Groundtruth mix ########################## grd_match_cum = tf.zeros(tf.pack([num_ex, timespan])) # Scale mix ratio on different timesteps. if knob_use_timescale: gt_knob_time_scale = tf.reshape( 1.0 + tf.log(1.0 + tf.to_float(tf.range(timespan)) * 3.0), [1, timespan, 1]) else: gt_knob_time_scale = tf.ones([1, timespan, 1]) # Mix in groundtruth box. global_step_box = tf.maximum(0.0, global_step - knob_box_offset) gt_knob_prob_box = tf.train.exponential_decay( knob_base, global_step_box, steps_per_knob_decay, knob_decay, staircase=False) gt_knob_prob_box = tf.minimum( 1.0, gt_knob_prob_box * gt_knob_time_scale) gt_knob_box = tf.to_float(tf.random_uniform( tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_box) model['gt_knob_prob_box'] = gt_knob_prob_box[0, 0, 0] # Mix in groundtruth segmentation. 
global_step_segm = tf.maximum(0.0, global_step - knob_segm_offset) gt_knob_prob_segm = tf.train.exponential_decay( knob_base, global_step_segm, steps_per_knob_decay, knob_decay, staircase=False) gt_knob_prob_segm = tf.minimum( 1.0, gt_knob_prob_segm * gt_knob_time_scale) gt_knob_segm = tf.to_float(tf.random_uniform( tf.pack([num_ex, timespan, 1]), 0, 1.0) <= gt_knob_prob_segm) model['gt_knob_prob_segm'] = gt_knob_prob_segm[0, 0, 0] ########################## # Segmentation output ########################## y_out = [None] * timespan y_out_lg_gamma = [None] * timespan y_out_beta = -5.0 ########################## # Computation graph ########################## for tt in xrange(timespan): # Controller CNN if use_canvas: ccnn_inp = tf.concat(3, [x, canvas]) acnn_inp = ccnn_inp else: ccnn_inp = x acnn_inp = x h_ccnn[tt] = ccnn(ccnn_inp) h_ccnn_last = h_ccnn[tt][-1] crnn_inp = tf.reshape(h_ccnn_last, [-1, crnn_inp_dim]) # Controller RNN [B, R1] crnn_state[tt], crnn_g_i[tt], crnn_g_f[tt], crnn_g_o[tt] = \ crnn_cell(crnn_inp, crnn_state[tt - 1]) h_crnn[tt] = tf.slice( crnn_state[tt], [0, crnn_dim], [-1, crnn_dim]) if use_gt_attn: attn_ctr[tt] = attn_ctr_gt[:, tt, :] attn_delta[tt] = attn_delta_gt[:, tt, :] attn_lg_var[tt] = attn_lg_var_gt[:, tt, :] attn_lg_gamma[tt] = attn_lg_gamma_gt[:, tt, :] attn_box_lg_gamma[tt] = attn_box_lg_gamma_gt[:, tt, :] y_out_lg_gamma[tt] = y_out_lg_gamma_gt[:, tt, :] else: ctrl_out = cmlp(h_crnn[tt])[-1] attn_ctr_norm[tt] = tf.slice(ctrl_out, [0, 0], [-1, 2]) attn_lg_size[tt] = tf.slice(ctrl_out, [0, 2], [-1, 2]) if squash_ctrl_params: # Restrict to (-1, 1) attn_ctr_norm[tt] = tf.tanh(attn_ctr_norm[tt]) # Restrict to (-inf, 0) attn_lg_size[tt] = -tf.nn.softplus(attn_lg_size[tt]) attn_ctr[tt], attn_size[tt] = base.get_unnormalized_attn( attn_ctr_norm[tt], attn_lg_size[tt], inp_height, inp_width) attn_delta[tt] = _get_delta_from_size( attn_size[tt], filter_height, filter_width) attn_lg_var[tt] = tf.zeros(tf.pack([num_ex, 2])) attn_lg_gamma[tt] = 
tf.slice(ctrl_out, [0, 6], [-1, 1]) attn_box_lg_gamma[tt] = tf.slice(ctrl_out, [0, 7], [-1, 1]) y_out_lg_gamma[tt] = tf.slice(ctrl_out, [0, 8], [-1, 1]) attn_gamma[tt] = tf.reshape( tf.exp(attn_lg_gamma[tt]), [-1, 1, 1, 1]) attn_box_gamma[tt] = tf.reshape( tf.exp(attn_box_lg_gamma[tt]), [-1, 1, 1, 1]) y_out_lg_gamma[tt] = tf.reshape(y_out_lg_gamma[tt], [-1, 1, 1, 1]) # Initial filters (predicted) filter_y = base.get_gaussian_filter( attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 0], inp_height, filter_height) filter_x = base.get_gaussian_filter( attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1], inp_width, filter_width) filter_y_inv = tf.transpose(filter_y, [0, 2, 1]) filter_x_inv = tf.transpose(filter_x, [0, 2, 1]) # Attention box if use_iou_box: _idx_map = base.get_idx_map( tf.pack([num_ex, inp_height, inp_width])) attn_top_left[tt], attn_bot_right[tt] = _get_attn_coord( attn_ctr[tt], attn_delta[tt], filter_height, filter_width) attn_box[tt] = base.get_filled_box_idx( _idx_map, attn_top_left[tt], attn_bot_right[tt]) attn_box[tt] = tf.reshape(attn_box[tt], [-1, 1, inp_height, inp_width]) else: attn_box[tt] = base.extract_patch(const_ones * attn_box_gamma[tt], filter_y_inv, filter_x_inv, 1) attn_box[tt] = tf.sigmoid(attn_box[tt] + attn_box_beta) attn_box[tt] = tf.reshape(attn_box[tt], [-1, 1, inp_height, inp_width]) # Kick in GT bbox. 
if use_knob: # IOU [B, 1, T] if use_iou_box: _top_left = tf.expand_dims(attn_top_left[tt], 1) _bot_right = tf.expand_dims(attn_bot_right[tt], 1) if not fixed_order: iou_soft_box[tt] = base.f_iou_box( _top_left, _bot_right, attn_top_left_gt, attn_bot_right_gt) else: if not fixed_order: iou_soft_box[tt] = base.f_inter( attn_box[tt], attn_box_gt) / \ base.f_union(attn_box[tt], attn_box_gt, eps=1e-5) if fixed_order: attn_ctr_gtm = attn_ctr_gt_noise[:, tt, :] attn_delta_gtm = attn_delta_gt_noise[:, tt, :] attn_size_gtm = attn_size_gt_noise[:, tt, :] else: grd_match = base.f_greedy_match( iou_soft_box[tt], grd_match_cum) # Let's try not using cumulative match. # grd_match_cum += grd_match # [B, T, 1] grd_match = tf.expand_dims(grd_match, 2) attn_ctr_gtm = tf.reduce_sum( grd_match * attn_ctr_gt_noise, 1) attn_delta_gtm = tf.reduce_sum( grd_match * attn_delta_gt_noise, 1) attn_size_gtm = tf.reduce_sum( grd_match * attn_size_gt_noise, 1) _gt_knob_box = gt_knob_box attn_ctr[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \ attn_ctr_gtm + \ (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \ attn_ctr[tt] attn_delta[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \ attn_delta_gtm + \ (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \ attn_delta[tt] attn_size[tt] = phase_train_f * _gt_knob_box[:, tt, 0: 1] * \ attn_size_gtm + \ (1 - phase_train_f * _gt_knob_box[:, tt, 0: 1]) * \ attn_size[tt] attn_top_left[tt], attn_bot_right[tt] = _get_attn_coord( attn_ctr[tt], attn_delta[tt], filter_height, filter_width) filter_y = base.get_gaussian_filter( attn_ctr[tt][:, 0], attn_size[tt][:, 0], attn_lg_var[tt][:, 0], inp_height, filter_height) filter_x = base.get_gaussian_filter( attn_ctr[tt][:, 1], attn_size[tt][:, 1], attn_lg_var[tt][:, 1], inp_width, filter_width) filter_y_inv = tf.transpose(filter_y, [0, 2, 1]) filter_x_inv = tf.transpose(filter_x, [0, 2, 1]) # Attended patch [B, A, A, D] x_patch[tt] = attn_gamma[tt] * base.extract_patch( acnn_inp, filter_y, filter_x, 
acnn_inp_depth) # CNN [B, A, A, D] => [B, RH2, RW2, RD2] h_acnn[tt] = acnn(x_patch[tt]) h_acnn_last[tt] = h_acnn[tt][-1] if use_attn_rnn: # RNN [B, T, R2] arnn_inp = tf.reshape(h_acnn_last[tt], [-1, arnn_inp_dim]) arnn_state[tt], arnn_g_i[tt], arnn_g_f[tt], arnn_g_o[tt] = \ arnn_cell(arnn_inp, arnn_state[tt - 1]) # Dense segmentation network [B, R] => [B, M] if use_attn_rnn: h_arnn = tf.slice( arnn_state[tt], [0, arnn_dim], [-1, arnn_dim]) amlp_inp = h_arnn else: amlp_inp = h_acnn_last[tt] amlp_inp = tf.reshape(amlp_inp, [-1, amlp_inp_dim]) h_core = amlp(amlp_inp)[-1] h_core_img = tf.reshape( h_core, [-1, arnn_h, arnn_w, attn_mlp_depth]) # DCNN skip = [None] + h_acnn[tt][::-1][1:] + [x_patch[tt]] h_adcnn[tt] = adcnn(h_core_img, skip=skip) # Output y_out[tt] = base.extract_patch( h_adcnn[tt][-1], filter_y_inv, filter_x_inv, 1) y_out[tt] = tf.exp(y_out_lg_gamma[tt]) * y_out[tt] + y_out_beta y_out[tt] = tf.sigmoid(y_out[tt]) y_out[tt] = tf.reshape(y_out[tt], [-1, 1, inp_height, inp_width]) # Scoring network smlp_inp = tf.concat(1, [h_crnn[tt], h_core]) s_out[tt] = smlp(smlp_inp)[-1] # Here is the knob kick in GT segmentations at this timestep. # [B, N, 1, 1] if use_canvas: if use_knob: _gt_knob_segm = tf.expand_dims( tf.expand_dims(gt_knob_segm[:, tt, 0: 1], 2), 3) if fixed_order: _y_out = tf.expand_dims(y_gt[:, tt, :, :], 3) else: grd_match = tf.expand_dims(grd_match, 3) _y_out = tf.expand_dims(tf.reduce_sum( grd_match * y_gt, 1), 3) # Add independent uniform noise to groundtruth. 
_noise = tf.random_uniform( tf.pack([num_ex, inp_height, inp_width, 1]), 0, 0.3) _y_out = _y_out - _y_out * _noise _y_out = phase_train_f * _gt_knob_segm * _y_out + \ (1 - phase_train_f * _gt_knob_segm) * \ tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1]) else: _y_out = tf.reshape(y_out[tt], [-1, inp_height, inp_width, 1]) canvas += tf.stop_gradient(_y_out) ######################### # Model outputs ######################### s_out = tf.concat(1, s_out) model['s_out'] = s_out y_out = tf.concat(1, y_out) model['y_out'] = y_out attn_box = tf.concat(1, attn_box) model['attn_box'] = attn_box x_patch = tf.concat(1, [tf.expand_dims(x_patch[tt], 1) for tt in xrange(timespan)]) model['x_patch'] = x_patch attn_top_left = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_top_left]) attn_bot_right = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_bot_right]) attn_ctr = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_ctr]) attn_size = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_size]) attn_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_lg_gamma]) attn_box_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in attn_box_lg_gamma]) y_out_lg_gamma = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in y_out_lg_gamma]) model['attn_ctr'] = attn_ctr model['attn_size'] = attn_size model['attn_top_left'] = attn_top_left model['attn_bot_right'] = attn_bot_right model['attn_box_gt'] = attn_box_gt ######################### # Loss function ######################### y_gt_shape = tf.shape(y_gt) num_ex_f = tf.to_float(y_gt_shape[0]) max_num_obj = tf.to_float(y_gt_shape[1]) ############################ # Box loss ############################ if fixed_order: # [B, T] for fixed order. iou_soft_box = base.f_iou( attn_box, attn_box_gt, pairwise=False) else: if use_knob: # [B, T, T] for matching. 
iou_soft_box = tf.concat( 1, [tf.expand_dims(iou_soft_box[tt], 1) for tt in xrange(timespan)]) else: iou_soft_box = base.f_iou(attn_box, attn_box_gt, timespan, pairwise=True) identity_match = base.get_identity_match(num_ex, timespan, s_gt) if fixed_order: match_box = identity_match else: match_box = base.f_segm_match(iou_soft_box, s_gt) model['match_box'] = match_box match_sum_box = tf.reduce_sum(match_box, reduction_indices=[2]) match_count_box = tf.reduce_sum(match_sum_box, reduction_indices=[1]) match_count_box = tf.maximum(1.0, match_count_box) # [B] if fixed order, [B, T] if matching. if fixed_order: iou_soft_box_mask = iou_soft_box else: iou_soft_box_mask = tf.reduce_sum( iou_soft_box * match_box, [1]) iou_soft_box = tf.reduce_sum(iou_soft_box_mask, [1]) iou_soft_box = tf.reduce_sum( iou_soft_box / match_count_box) / num_ex_f if box_loss_fn == 'iou': box_loss = -iou_soft_box elif box_loss_fn == 'wt_cov': box_loss = -base.f_weighted_coverage(iou_soft_box, attn_box_gt) elif box_loss_fn == 'mse': box_loss_fn = base.f_match_loss( y_out, y_gt, match_box, timespan, f_mse) elif box_loss_fn == 'bce': box_loss_fn = base.f_match_loss( y_out, y_gt, match_box, timespan, f_bce) else: raise Exception('Unknown box_loss_fn: {}'.format(box_loss_fn)) model['box_loss'] = box_loss box_loss_coeff = tf.constant(1.0) model['box_loss_coeff'] = box_loss_coeff tf.add_to_collection('losses', box_loss_coeff * box_loss) ############################## # Segmentation loss ############################## # IoU (soft) iou_soft_pairwise = base.f_iou(y_out, y_gt, timespan, pairwise=True) real_match = base.f_segm_match(iou_soft_pairwise, s_gt) if fixed_order: iou_soft = base.f_iou(y_out, y_gt, pairwise=False) match = identity_match else: iou_soft = iou_soft_pairwise match = real_match model['match'] = match match_sum = tf.reduce_sum(match, reduction_indices=[2]) match_count = tf.reduce_sum(match_sum, reduction_indices=[1]) match_count = tf.maximum(1.0, match_count) # Weighted coverage (soft) 
wt_cov_soft = base.f_weighted_coverage(iou_soft_pairwise, y_gt) model['wt_cov_soft'] = wt_cov_soft unwt_cov_soft = base.f_unweighted_coverage( iou_soft_pairwise, match_count) model['unwt_cov_soft'] = unwt_cov_soft # [B] if fixed order, [B, T] if matching. if fixed_order: iou_soft_mask = iou_soft else: iou_soft_mask = tf.reduce_sum(iou_soft * match, [1]) iou_soft = tf.reduce_sum(iou_soft_mask, [1]) iou_soft = tf.reduce_sum(iou_soft / match_count) / num_ex_f model['iou_soft'] = iou_soft if segm_loss_fn == 'iou': segm_loss = -iou_soft elif segm_loss_fn == 'wt_cov': segm_loss = -wt_cov_soft elif segm_loss_fn == 'bce': segm_loss = f_match_bce(y_out, y_gt, match, timespan) else: raise Exception('Unknown segm_loss_fn: {}'.format(segm_loss_fn)) model['segm_loss'] = segm_loss segm_loss_coeff = 1.0 tf.add_to_collection('losses', segm_loss_coeff * segm_loss) #################### # Score loss #################### conf_loss = base.f_conf_loss(s_out, match, timespan, use_cum_min=True) model['conf_loss'] = conf_loss tf.add_to_collection('losses', loss_mix_ratio * conf_loss) #################### # Total loss #################### total_loss = tf.add_n(tf.get_collection( 'losses'), name='total_loss') model['loss'] = total_loss #################### # Optimizer #################### learn_rate = tf.train.exponential_decay( base_learn_rate, global_step, steps_per_learn_rate_decay, learn_rate_decay, staircase=True) model['learn_rate'] = learn_rate eps = 1e-7 train_step = GradientClipOptimizer( tf.train.AdamOptimizer(learn_rate, epsilon=eps), clip=1.0).minimize(total_loss, global_step=global_step) model['train_step'] = train_step #################### # Statistics #################### # Here statistics (hard measures) is always using matching. 
y_out_hard = tf.to_float(y_out > 0.5) iou_hard = base.f_iou(y_out_hard, y_gt, timespan, pairwise=True) wt_cov_hard = base.f_weighted_coverage(iou_hard, y_gt) model['wt_cov_hard'] = wt_cov_hard unwt_cov_hard = base.f_unweighted_coverage(iou_hard, match_count) model['unwt_cov_hard'] = unwt_cov_hard iou_hard_mask = tf.reduce_sum(iou_hard * real_match, [1]) iou_hard = tf.reduce_sum(tf.reduce_sum(iou_hard_mask, [1]) / match_count) / num_ex_f model['iou_hard'] = iou_hard dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True) dice = tf.reduce_sum(tf.reduce_sum( dice * real_match, reduction_indices=[1, 2]) / match_count) / num_ex_f model['dice'] = dice model['count_acc'] = _count_acc(s_out, s_gt) model['dic'] = _dic(s_out, s_gt, abs=False) model['dic_abs'] = _dic(s_out, s_gt, abs=True) ################################ # Controller output statistics ################################ attn_lg_gamma_mean = tf.reduce_sum(attn_lg_gamma) / num_ex_f / timespan attn_box_lg_gamma_mean = tf.reduce_sum( attn_box_lg_gamma) / num_ex_f / timespan y_out_lg_gamma_mean = tf.reduce_sum( y_out_lg_gamma) / num_ex_f / timespan model['attn_lg_gamma_mean'] = attn_lg_gamma_mean model['attn_box_lg_gamma_mean'] = attn_box_lg_gamma_mean model['y_out_lg_gamma_mean'] = y_out_lg_gamma_mean ################################## # Controller RNN gate statistics ################################## crnn_g_i = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in crnn_g_i]) crnn_g_f = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in crnn_g_f]) crnn_g_o = tf.concat(1, [tf.expand_dims(tmp, 1) for tmp in crnn_g_o]) crnn_g_i_avg = tf.reduce_sum( crnn_g_i) / num_ex_f / timespan / ctrl_rnn_hid_dim crnn_g_f_avg = tf.reduce_sum( crnn_g_f) / num_ex_f / timespan / ctrl_rnn_hid_dim crnn_g_o_avg = tf.reduce_sum( crnn_g_o) / num_ex_f / timespan / ctrl_rnn_hid_dim model['crnn_g_i_avg'] = crnn_g_i_avg model['crnn_g_f_avg'] = crnn_g_f_avg model['crnn_g_o_avg'] = crnn_g_o_avg return model
def get_model(opt, device='/cpu:0'):
    """Build the CNN -> RNN -> DCNN instance segmentation graph.

    The graph encodes the input image with a CNN, unrolls a recurrent cell
    for ``timespan`` steps over the CNN output, decodes one segmentation map
    and one confidence score per step, and attaches the matching-based
    losses, the optimizer step and evaluation statistics.

    Args:
        opt: Dictionary of model options. Every key listed at the top of
            the body is read eagerly (a missing key raises ``KeyError``);
            ``opt['score_maxpool']`` is read only when the score is
            predicted from a ``conv_lstm`` hidden state.
        device: Device string handed to ``base.get_device_fn``.

    Returns:
        A dictionary of graph nodes. Keys set directly by this function:
        ``x``, ``phase_train``, ``y_gt``, ``s_gt`` (placeholders),
        ``x_trans``, ``y_gt_trans``, ``y_out``, ``s_out``, ``learn_rate``,
        ``match``, ``wt_cov_soft``, ``unwt_cov_soft``, ``iou_soft``,
        ``wt_iou_soft``, ``segm_loss``, ``conf_loss``, ``loss``,
        ``total_loss``, ``train_step``, ``wt_cov_hard``, ``unwt_cov_hard``,
        ``iou_hard``, ``wt_iou_hard``, ``dice``, ``count_acc``, ``dic`` and
        ``dic_abs``. The same dictionary is also passed as ``model=`` to
        ``nn.cnn`` / ``nn.dcnn``, which presumably register extra entries
        in it -- verify against those helpers.

    Raises:
        Exception: If ``opt['rnn_type']`` is not one of ``'conv_lstm'``,
            ``'lstm'``, ``'gru'``, or if ``opt['segm_loss_fn']`` is not one
            of ``'iou'``, ``'wt_iou'``, ``'wt_cov'``, ``'bce'``.
    """
    model = {}

    timespan = opt['timespan']
    inp_height = opt['inp_height']
    inp_width = opt['inp_width']
    inp_depth = opt['inp_depth']
    padding = opt['padding']
    rnn_type = opt['rnn_type']
    cnn_filter_size = opt['cnn_filter_size']
    cnn_depth = opt['cnn_depth']
    dcnn_filter_size = opt['dcnn_filter_size']
    dcnn_depth = opt['dcnn_depth']
    conv_lstm_filter_size = opt['conv_lstm_filter_size']
    conv_lstm_hid_depth = opt['conv_lstm_hid_depth']
    rnn_hid_dim = opt['rnn_hid_dim']
    mlp_depth = opt['mlp_depth']
    wd = opt['weight_decay']
    segm_dense_conn = opt['segm_dense_conn']
    use_bn = opt['use_bn']
    use_deconv = opt['use_deconv']
    add_skip_conn = opt['add_skip_conn']
    score_use_core = opt['score_use_core']
    loss_mix_ratio = opt['loss_mix_ratio']
    base_learn_rate = opt['base_learn_rate']
    learn_rate_decay = opt['learn_rate_decay']
    steps_per_learn_rate_decay = opt['steps_per_learn_rate_decay']
    num_mlp_layers = opt['num_mlp_layers']
    mlp_dropout_ratio = opt['mlp_dropout']
    segm_loss_fn = opt['segm_loss_fn']
    clip_gradient = opt['clip_gradient']
    rnd_hflip = opt['rnd_hflip']
    rnd_vflip = opt['rnd_vflip']
    rnd_transpose = opt['rnd_transpose']
    rnd_colour = opt['rnd_colour']

    with tf.device(base.get_device_fn(device)):
        # Input image, [B, H, W, D]
        x = tf.placeholder('float', [None, inp_height, inp_width, inp_depth])
        # Whether in training stage, required for batch norm.
        phase_train = tf.placeholder('bool')
        # Groundtruth segmentation maps, [B, T, H, W]
        y_gt = tf.placeholder(
            'float', [None, timespan, inp_height, inp_width])
        # Groundtruth confidence score, [B, T]
        s_gt = tf.placeholder('float', [None, timespan])
        model['x'] = x
        model['phase_train'] = phase_train
        model['y_gt'] = y_gt
        model['s_gt'] = s_gt

        # Dynamic batch size (integer tensor), used to size initial states.
        x_shape = tf.shape(x)
        num_ex = x_shape[0]

        # Random image transformation
        x, y_gt = img.random_transformation(
            x, y_gt, padding, phase_train,
            rnd_hflip=rnd_hflip, rnd_vflip=rnd_vflip,
            rnd_transpose=rnd_transpose, rnd_colour=rnd_colour)
        model['x_trans'] = x
        model['y_gt_trans'] = y_gt

        # CNN
        cnn_filters = cnn_filter_size
        cnn_nlayers = len(cnn_filters)
        cnn_channels = [inp_depth] + cnn_depth
        cnn_pool = [2] * cnn_nlayers
        cnn_act = [tf.nn.relu] * cnn_nlayers
        cnn_use_bn = [use_bn] * cnn_nlayers
        cnn = nn.cnn(cnn_filters, cnn_channels, cnn_pool, cnn_act,
                     cnn_use_bn, phase_train=phase_train, wd=wd, model=model)
        h_cnn = cnn(x)
        h_pool3 = h_cnn[-1]

        # RNN input size
        # NOTE(review): `/` is relied on to be integer division here (the
        # file uses xrange, i.e. Python 2) -- confirm there is no
        # `from __future__ import division` at the top of the file.
        subsample = np.array(cnn_pool).prod()
        rnn_h = inp_height / subsample
        rnn_w = inp_width / subsample

        # Low-res segmentation depth
        core_depth = mlp_depth if segm_dense_conn else 1
        core_dim = rnn_h * rnn_w * core_depth

        # One slot per step plus a trailing slot: the initial state is stored
        # at index -1 so that step 0 can read `rnn_state[tt - 1]`.
        rnn_state = [None] * (timespan + 1)

        # RNN
        if rnn_type == 'conv_lstm':
            rnn_depth = conv_lstm_hid_depth
            rnn_dim = rnn_h * rnn_w * rnn_depth
            conv_lstm_inp_depth = cnn_channels[-1]
            rnn_inp = h_pool3
            rnn_state[-1] = tf.zeros(
                tf.pack([num_ex, rnn_h, rnn_w, rnn_depth * 2]))
            rnn_cell = nn.conv_lstm(conv_lstm_inp_depth, rnn_depth,
                                    conv_lstm_filter_size, wd=wd)
        elif rnn_type == 'lstm' or rnn_type == 'gru':
            rnn_dim = rnn_hid_dim
            rnn_inp_dim = rnn_h * rnn_w * cnn_channels[-1]
            rnn_inp = tf.reshape(
                h_pool3, [-1, rnn_h * rnn_w * cnn_channels[-1]])
            if rnn_type == 'lstm':
                rnn_state[-1] = tf.zeros(tf.pack([num_ex, rnn_hid_dim * 2]))
                rnn_cell = nn.lstm(rnn_inp_dim, rnn_hid_dim, wd=wd)
            else:
                rnn_state[-1] = tf.zeros(tf.pack([num_ex, rnn_hid_dim]))
                rnn_cell = nn.gru(rnn_inp_dim, rnn_hid_dim, wd=wd)
        else:
            raise Exception('Unknown RNN type: {}'.format(rnn_type))

        # Unroll the recurrent cell; the same input is fed at every step.
        # NOTE(review): assumes the GRU cell returns the same 4-tuple as the
        # LSTM cells -- confirm against nn.gru.
        for tt in xrange(timespan):
            rnn_state[tt], _gi, _gf, _go = rnn_cell(
                rnn_inp, rnn_state[tt - 1])

        # Extract the per-step hidden representation from the state.
        if rnn_type == 'conv_lstm':
            h_rnn = [tf.slice(rnn_state[tt], [0, 0, 0, rnn_depth],
                              [-1, -1, -1, rnn_depth])
                     for tt in xrange(timespan)]
        elif rnn_type == 'lstm':
            h_rnn = [tf.slice(rnn_state[tt], [0, rnn_dim], [-1, rnn_dim])
                     for tt in xrange(timespan)]
        elif rnn_type == 'gru':
            # The GRU state is allocated with width rnn_hid_dim (no doubled
            # half to slice off), so each step's state is used as is.
            # Previously this referenced the undefined name `state`.
            h_rnn = rnn_state[:timespan]
        h_rnn_all = tf.concat(
            1, [tf.expand_dims(h_rnn[tt], 1) for tt in xrange(timespan)])

        # Core segmentation network.
        if segm_dense_conn:
            # Dense segmentation network
            h_rnn_all = tf.reshape(h_rnn_all, [-1, rnn_dim])
            mlp_dims = [rnn_dim] + [core_dim] * num_mlp_layers
            mlp_act = [tf.nn.relu] * num_mlp_layers
            mlp_dropout = [1.0 - mlp_dropout_ratio] * num_mlp_layers
            segm_mlp = nn.mlp(mlp_dims, mlp_act, mlp_dropout,
                              phase_train=phase_train, wd=wd)
            h_core = segm_mlp(h_rnn_all)[-1]
            h_core = tf.reshape(h_core, [-1, rnn_h, rnn_w, mlp_depth])
        else:
            # Convolutional segmentation network
            # NOTE(review): `rnn_depth` is only defined for the conv_lstm
            # cell, so this branch appears to require rnn_type='conv_lstm'.
            w_segm_conv = nn.weight_variable([3, 3, rnn_depth, 1], wd=wd)
            b_segm_conv = nn.weight_variable([1], wd=wd)
            b_log_softmax = nn.weight_variable([1])
            h_rnn_all = tf.reshape(
                h_rnn_all, [-1, rnn_h, rnn_w, rnn_depth])
            h_core = tf.reshape(tf.log(tf.nn.softmax(tf.reshape(
                nn.conv2d(h_rnn_all, w_segm_conv) + b_segm_conv,
                [-1, rnn_h * rnn_w]))) + b_log_softmax,
                [-1, rnn_h, rnn_w, 1])

        # Deconv net to upsample
        if use_deconv:
            dcnn_filters = dcnn_filter_size
            dcnn_nlayers = len(dcnn_filters)
            dcnn_unpool = [2] * (dcnn_nlayers - 1) + [1]
            dcnn_act = [tf.nn.relu] * (dcnn_nlayers - 1) + [tf.sigmoid]
            if segm_dense_conn:
                dcnn_channels = [mlp_depth] + dcnn_depth
            else:
                dcnn_channels = [1] * (dcnn_nlayers + 1)
            dcnn_use_bn = [use_bn] * dcnn_nlayers
            skip = None
            skip_ch = None
            if add_skip_conn:
                skip, skip_ch = build_skip_conn(
                    cnn_channels, h_cnn, x, timespan)
            dcnn = nn.dcnn(dcnn_filters, dcnn_channels, dcnn_unpool,
                           dcnn_act, dcnn_use_bn, skip_ch=skip_ch,
                           phase_train=phase_train, wd=wd, model=model)
            h_dcnn = dcnn(h_core, skip=skip)
            y_out = tf.reshape(
                h_dcnn[-1], [-1, timespan, inp_height, inp_width])
        else:
            # Plain bilinear upsampling of the low-res core output.
            # (Fixes the former `inp_wiidth` typo, which raised NameError.)
            y_out = tf.reshape(
                tf.image.resize_bilinear(h_core, [inp_height, inp_width]),
                [-1, timespan, inp_height, inp_width])
        model['y_out'] = y_out

        # Scoring network
        if score_use_core:
            # Use core network to predict score
            score_inp = h_core
            score_inp_shape = [-1, core_dim]
            score_inp = tf.reshape(score_inp, score_inp_shape)
            score_dim = core_dim
        else:
            # Use RNN hidden state to predict score
            score_inp = h_rnn_all
            if rnn_type == 'conv_lstm':
                score_inp_shape = [-1, rnn_h, rnn_w, rnn_depth]
                score_inp = tf.reshape(score_inp, score_inp_shape)
                score_maxpool = opt['score_maxpool']
                score_dim = rnn_h * rnn_w / (score_maxpool ** 2) * rnn_depth
                if score_maxpool > 1:
                    score_inp = nn.max_pool(score_inp, score_maxpool)
                score_inp = tf.reshape(score_inp, [-1, score_dim])
            else:
                score_inp_shape = [-1, rnn_dim]
                score_inp = tf.reshape(score_inp, score_inp_shape)
                score_dim = rnn_dim
        score_mlp = nn.mlp(dims=[score_dim, 1], act=[tf.sigmoid], wd=wd)
        s_out = score_mlp(score_inp)[-1]
        s_out = tf.reshape(s_out, [-1, timespan])
        model['s_out'] = s_out

        # Loss function
        global_step = tf.Variable(0.0)
        learn_rate = tf.train.exponential_decay(
            base_learn_rate, global_step, steps_per_learn_rate_decay,
            learn_rate_decay, staircase=True)
        model['learn_rate'] = learn_rate
        eps = 1e-7

        # Float batch size for averaging. Kept under its own name so the
        # integer `num_ex` tensor above is not silently rebound.
        y_gt_shape = tf.shape(y_gt)
        num_ex_f = tf.to_float(y_gt_shape[0])

        # Pairwise IOU
        iou_soft_pairwise = base.f_iou(y_out, y_gt, timespan, pairwise=True)

        # Matching
        match = base.f_segm_match(iou_soft_pairwise, s_gt)
        model['match'] = match
        match_sum = tf.reduce_sum(match, reduction_indices=[2])
        match_count = tf.reduce_sum(match_sum, reduction_indices=[1])
        # Lower-bound the divisor so an example without any match yields 0
        # instead of NaN in the per-example averages below.
        match_count = tf.maximum(1.0, match_count)

        # Weighted coverage (soft)
        wt_cov_soft = base.f_weighted_coverage(iou_soft_pairwise, y_gt)
        model['wt_cov_soft'] = wt_cov_soft
        unwt_cov_soft = base.f_unweighted_coverage(
            iou_soft_pairwise, match_count)
        model['unwt_cov_soft'] = unwt_cov_soft

        # IOU (soft)
        iou_soft_mask = tf.reduce_sum(iou_soft_pairwise * match, [1])
        iou_soft = tf.reduce_sum(
            tf.reduce_sum(iou_soft_mask, [1]) / match_count) / num_ex_f
        model['iou_soft'] = iou_soft
        gt_wt = coverage_weight(y_gt)
        wt_iou_soft = tf.reduce_sum(
            tf.reduce_sum(iou_soft_mask * gt_wt, [1]) /
            match_count) / num_ex_f
        model['wt_iou_soft'] = wt_iou_soft

        if segm_loss_fn == 'iou':
            segm_loss = -iou_soft
        elif segm_loss_fn == 'wt_iou':
            segm_loss = -wt_iou_soft
        elif segm_loss_fn == 'wt_cov':
            segm_loss = -wt_cov_soft
        elif segm_loss_fn == 'bce':
            segm_loss = base.f_match_loss(
                y_out, y_gt, match, timespan, base.f_bce)
        else:
            # Fail with a clear message instead of a later NameError on
            # `segm_loss`.
            raise Exception(
                'Unknown segm_loss_fn: {}'.format(segm_loss_fn))
        model['segm_loss'] = segm_loss

        conf_loss = base.f_conf_loss(s_out, match, timespan, use_cum_min=True)
        model['conf_loss'] = conf_loss
        loss = loss_mix_ratio * conf_loss + segm_loss
        model['loss'] = loss
        tf.add_to_collection('losses', loss)
        total_loss = tf.add_n(tf.get_collection(
            'losses'), name='total_loss')
        model['total_loss'] = total_loss

        train_step = GradientClipOptimizer(
            tf.train.AdamOptimizer(learn_rate, epsilon=eps),
            clip=clip_gradient).minimize(total_loss, global_step=global_step)
        model['train_step'] = train_step

        # Statistics
        # [B, M, N] * [B, M, N] => [B] * [B] => [1]
        y_out_hard = tf.to_float(y_out > 0.5)
        # Computed once and reused for every hard statistic below.
        iou_hard_pairwise = base.f_iou(
            y_out_hard, y_gt, timespan, pairwise=True)
        wt_cov_hard = base.f_weighted_coverage(iou_hard_pairwise, y_gt)
        model['wt_cov_hard'] = wt_cov_hard
        unwt_cov_hard = base.f_unweighted_coverage(
            iou_hard_pairwise, match_count)
        model['unwt_cov_hard'] = unwt_cov_hard
        # [B, T]
        iou_hard_mask = tf.reduce_sum(iou_hard_pairwise * match, [1])
        iou_hard = tf.reduce_sum(tf.reduce_sum(
            iou_hard_pairwise * match, reduction_indices=[1, 2]) /
            match_count) / num_ex_f
        model['iou_hard'] = iou_hard
        wt_iou_hard = tf.reduce_sum(
            tf.reduce_sum(iou_hard_mask * gt_wt, [1]) /
            match_count) / num_ex_f
        model['wt_iou_hard'] = wt_iou_hard
        dice = base.f_dice(y_out_hard, y_gt, timespan, pairwise=True)
        dice = tf.reduce_sum(
            tf.reduce_sum(dice * match, [1, 2]) / match_count) / num_ex_f
        model['dice'] = dice
        model['count_acc'] = base.f_count_acc(s_out, s_gt)
        model['dic'] = base.f_dic(s_out, s_gt, abs=False)
        model['dic_abs'] = base.f_dic(s_out, s_gt, abs=True)

    return model