def test_create_csn(self):
    """
    Test simple CSN with different inputs.

    For every combination of input channel count, clip length and crop
    size, a small CSN is built and fed each tensor produced by
    ``TestCSN._get_inputs``. Tensors whose channel dimension (dim 1)
    differs from ``input_channel`` must raise ``RuntimeError``; every
    other tensor must yield an output of shape ``(batch_size, num_class)``.
    """
    num_class = 400
    stem_conv_stride = (1, 2, 2)
    stage_spatial_stride = (1, 2, 2, 2)
    stage_temporal_stride = (1, 2, 2, 1)
    # stem_pool is not passed to create_csn below, so this assumes its
    # default adds no extra downsampling: the total stride is the stem conv
    # stride times the product of the per-stage strides. The same tuples are
    # handed to create_csn so the head kernel cannot drift out of sync.
    total_spatial_stride = stem_conv_stride[1] * np.prod(stage_spatial_stride)
    total_temporal_stride = stem_conv_stride[0] * np.prod(stage_temporal_stride)

    for input_channel, input_clip_length, input_crop_size in itertools.product(
        (3, 2), (4, 8), (56, 64)
    ):
        # Input extent divided by the total stride, i.e. the expected
        # (T, H, W) size of the final feature map.
        head_pool_kernel_size = (
            input_clip_length // total_temporal_stride,
            input_crop_size // total_spatial_stride,
            input_crop_size // total_spatial_stride,
        )

        model = create_csn(
            input_channel=input_channel,
            model_depth=50,
            model_num_class=num_class,
            dropout_rate=0,
            norm=nn.BatchNorm3d,
            activation=nn.ReLU,
            stem_dim_out=8,
            stem_conv_kernel_size=(3, 7, 7),
            stem_conv_stride=stem_conv_stride,
            stage_conv_a_kernel_size=(1, 1, 1),
            stage_conv_b_kernel_size=(3, 3, 3),
            stage_conv_b_width_per_group=1,
            stage_spatial_stride=stage_spatial_stride,
            stage_temporal_stride=stage_temporal_stride,
            bottleneck=create_bottleneck_block,
            head_pool=nn.AvgPool3d,
            head_pool_kernel_size=head_pool_kernel_size,
            head_output_size=(1, 1, 1),
            head_activation=nn.Softmax,
        )

        # Test forwarding.
        for tensor in TestCSN._get_inputs(
            input_channel, input_clip_length, input_crop_size
        ):
            if tensor.shape[1] != input_channel:
                # A channel count that does not match the model's
                # input_channel is expected to be rejected.
                with self.assertRaises(RuntimeError):
                    model(tensor)
                continue

            output_shape = model(tensor).shape
            output_shape_gt = (tensor.shape[0], num_class)
            self.assertEqual(
                output_shape,
                output_shape_gt,
                f"Output shape {output_shape} is different from expected shape {output_shape_gt}",
            )
def _construct_network(self, cfg):
    """
    Builds a single-pathway CSN model via ``create_csn``.

    The result is stored in ``self.model``, and the head activation from
    ``cfg.MODEL.HEAD_ACT`` is stored in ``self.post_act``. The head pooling
    kernel is derived from the stem and stage strides defined below, rather
    than from hard-coded divisors, so it stays consistent if those strides
    change.

    Args:
        cfg (CfgNode): model building configs, details are in the
            comments of the config file.
    """
    import operator
    from functools import reduce

    def _prod(values):
        # Product of a stride tuple (avoids math.prod, which needs 3.8+).
        return reduce(operator.mul, values, 1)

    # Params from configs.
    norm_module = get_norm(cfg)

    # Downsampling layout. Index 0 is temporal and indices 1-2 are spatial,
    # matching the (frames, crop, crop) order of the head kernel below.
    stem_conv_stride = (1, 2, 2)
    stem_pool_stride = (1, 2, 2)
    stage_spatial_stride = (1, 2, 2, 2)
    stage_temporal_stride = (1, 2, 2, 2)

    # Total downsampling from input clip to head. With the strides above this
    # is 8 temporally and 32 spatially (stem conv x stem pool x stages).
    total_temporal_stride = (
        stem_conv_stride[0] * stem_pool_stride[0] * _prod(stage_temporal_stride)
    )
    total_spatial_stride = (
        stem_conv_stride[1] * stem_pool_stride[1] * _prod(stage_spatial_stride)
    )

    self.model = create_csn(
        # Input clip configs.
        input_channel=cfg.DATA.INPUT_CHANNEL_NUM[0],
        # Model configs.
        model_depth=cfg.RESNET.DEPTH,
        model_num_class=cfg.MODEL.NUM_CLASSES,
        dropout_rate=cfg.MODEL.DROPOUT_RATE,
        # Normalization configs.
        norm=norm_module,
        # Activation configs.
        activation=partial(nn.ReLU, inplace=cfg.RESNET.INPLACE_RELU),
        # Stem configs.
        stem_dim_out=cfg.RESNET.WIDTH_PER_GROUP,
        stem_conv_kernel_size=(3, 7, 7),
        stem_conv_stride=stem_conv_stride,
        stem_pool=nn.MaxPool3d,
        stem_pool_kernel_size=(1, 3, 3),
        stem_pool_stride=stem_pool_stride,
        # Stage configs.
        stage_conv_a_kernel_size=(1, 1, 1),
        stage_conv_b_kernel_size=(3, 3, 3),
        stage_conv_b_width_per_group=1,
        stage_spatial_stride=stage_spatial_stride,
        stage_temporal_stride=stage_temporal_stride,
        bottleneck=create_bottleneck_block,
        # Head configs.
        head_pool=nn.AvgPool3d,
        head_pool_kernel_size=(
            cfg.DATA.NUM_FRAMES // total_temporal_stride,
            cfg.DATA.TRAIN_CROP_SIZE // total_spatial_stride,
            cfg.DATA.TRAIN_CROP_SIZE // total_spatial_stride,
        ),
        head_activation=None,
        head_output_with_global_average=False,
    )

    self.post_act = get_head_act(cfg.MODEL.HEAD_ACT)