def resnet_block(input, strides, out_channels, name):
    """Build a bottleneck residual block.

    Main path: 1x1 conv (ReLU) -> 3x3 conv carrying ``strides`` (ReLU) ->
    1x1 conv expanding to ``4 * out_channels`` (no activation).  The
    shortcut is ``input`` itself unless the block downsamples
    (``strides[2] > 1``) or the input channel count differs from
    ``4 * out_channels``; in that case a 1x1 projection conv is built.
    The shortcut and main path are summed and passed through ReLU.

    Args:
        input: Input tensor. Its channel count is read from ``shape[1]``
            (TF1 ``Dimension.value``); this appears to assume a
            channels-first layout -- TODO confirm.
        strides: Stride tuple passed to the 3x3 conv and the projection;
            element ``[2]`` is checked for downsampling.
        out_channels: Bottleneck width; the block outputs
            ``4 * out_channels`` channels.
        name: Prefix for the names of the convolutions created here.

    Returns:
        ``tf.nn.relu(shortcut + main_path)``.
    """
    in_channels = input.shape[1].value
    expanded = out_channels * 4
    unit_stride = (1, 1, 1, 1)

    reduced = make_conv2d(input_tensor=input,
                          filter_shape=(1, 1, in_channels, out_channels),
                          strides=unit_stride, padding="SAME",
                          actimode="RELU", name=name + "_conv1")
    spatial = make_conv2d(input_tensor=reduced,
                          filter_shape=(3, 3, out_channels, out_channels),
                          strides=strides, padding="SAME",
                          actimode="RELU", name=name + "_conv2")
    residual = make_conv2d(input_tensor=spatial,
                           filter_shape=(1, 1, out_channels, expanded),
                           strides=unit_stride, padding="SAME",
                           actimode="NONE", name=name + "_conv3")

    shortcut = input
    needs_projection = strides[2] > 1 or in_channels != expanded
    if needs_projection:
        # NOTE(review): the projection applies ReLU here; projection
        # shortcuts are often linear -- confirm this is intended.
        shortcut = make_conv2d(input_tensor=input,
                               filter_shape=(1, 1, in_channels, expanded),
                               strides=strides, padding="SAME",
                               actimode="RELU", name=name + "_conv4")
    return tf.nn.relu(tf.add(shortcut, residual))
def resnext_block(input, strides, out_channels, groups, name):
    """Build a ResNeXt-style grouped bottleneck block.

    Main path: 1x1 conv to ``out_channels`` (ReLU) -> split the channel
    axis into ``groups`` slices, each convolved by its own 3x3 conv with
    ``strides`` to ``out_channels // groups`` channels (ReLU) -> concat ->
    1x1 conv expanding to ``2 * out_channels`` (no activation).  The
    shortcut is ``input`` unless the block downsamples (``strides[2] > 1``)
    or the input channel count differs from ``2 * out_channels``; in that
    case a 1x1 projection conv is built.  The two paths are summed and
    passed through ReLU.

    Args:
        input: Input tensor. Its channel count is read from ``shape[1]``,
            and the split/concat use ``axis=1``; this appears to assume a
            channels-first layout -- TODO confirm.
        strides: Stride tuple for the grouped 3x3 convs and the projection;
            element ``[2]`` is checked for downsampling.
        out_channels: Width of the grouped stage; the block outputs
            ``2 * out_channels`` channels.
        groups: Number of channel groups for the 3x3 stage.
        name: Prefix for the names of the ops created here.

    Returns:
        ``tf.nn.relu(shortcut + main_path)``.
    """
    in_channels = input.shape[1].value
    expanded = out_channels * 2

    t = make_conv2d(input_tensor=input,
                    filter_shape=(1, 1, in_channels, out_channels),
                    strides=(1, 1, 1, 1), padding="SAME",
                    actimode="RELU", name=name + "_conv1")

    # Grouped 3x3 conv: split channels, convolve each slice independently,
    # then concatenate the results back along the channel axis.
    branches = tf.split(t, groups, axis=1, name=name + "_split")
    assert len(branches) == groups
    branches = [
        make_conv2d(input_tensor=branch,
                    filter_shape=(3, 3, branch.shape[1].value,
                                  out_channels // groups),
                    strides=strides, padding="SAME", actimode="RELU",
                    # Suffix with the group index so each branch's conv
                    # gets a distinct name.
                    name=name + "_conv2_{}".format(i))
        for i, branch in enumerate(branches)
    ]
    output = tf.concat(branches, axis=1, name=name + "_concat")

    t = make_conv2d(input_tensor=output,
                    filter_shape=(1, 1, output.shape[1].value, expanded),
                    strides=(1, 1, 1, 1), padding="SAME",
                    actimode="NONE", name=name + "_conv3")

    if (strides[2] > 1) or (in_channels != expanded):
        # NOTE(review): the projection applies ReLU; confirm intended.
        input = make_conv2d(input_tensor=input,
                            filter_shape=(1, 1, in_channels, expanded),
                            strides=strides, padding="SAME",
                            actimode="RELU", name=name + "_conv4")
    return tf.nn.relu(tf.add(input, t))
def squeeze(out_channels, input, name="squeeze"):
    """Project ``input`` to ``out_channels`` channels with a 1x1 conv + ReLU.

    Args:
        out_channels: Number of output channels.
        input: Input tensor. Its channel count is read from ``shape[1]``
            (TF1 ``Dimension.value``).
        name: Name given to the convolution. Defaults to ``"squeeze"``
            (the previously hard-coded value, so existing callers are
            unaffected); pass a distinct name when this is used more than
            once in a graph so the convs can be told apart.

    Returns:
        The output of ``make_conv2d`` for the 1x1, stride-1, "SAME" conv.
    """
    return make_conv2d(input_tensor=input,
                       filter_shape=(1, 1, input.shape[1].value, out_channels),
                       strides=(1, 1, 1, 1), padding="SAME",
                       actimode="RELU", name=name)
def fit(current, input):
    """Adapt ``input`` toward the channel count of ``current``.

    If ``input`` and ``current`` have the same size along axis 2, ``input``
    is passed through ``squeeze`` (1x1 conv + ReLU) to ``current``'s
    channel count (``shape[1]``).  Otherwise a 3x3 conv with strides
    ``(1, 1, 2, 2)`` and ReLU is applied, named ``"fit"``.

    NOTE(review): the stride-2 branch presumably assumes ``input`` is
    exactly twice ``current``'s spatial size; no check enforces it.

    Args:
        current: Tensor whose shape is the target.
        input: Tensor to adapt.

    Returns:
        The adapted tensor.
    """
    target_channels = current.shape[1].value
    if input.shape[2].value != current.shape[2].value:
        return make_conv2d(input_tensor=input,
                           filter_shape=(3, 3, input.shape[1].value,
                                         target_channels),
                           strides=(1, 1, 2, 2), padding="SAME",
                           actimode="RELU", name="fit")
    return squeeze(target_channels, input)