Exemplo n.º 1
0
    def __init__(self, pretrained: bool = True, frame_num=4, resnet_3d_pretrained_model='r3d18_KM_200ep.pth'):
        """Build the backbone and the pairwise-combination classification head.

        Args:
            pretrained: When true, load backbone weights from
                ``resnet_3d_pretrained_model``.
            frame_num: Number of items that are paired up; must be at least 2.
                It fixes both the number of index pairs and the output size
                (``frame_num! // 2`` classes).
            resnet_3d_pretrained_model: Path to a checkpoint holding the
                backbone weights under the ``'state_dict'`` key.

        Raises:
            ValueError: If ``frame_num`` is smaller than 2 (no pair can be
                formed and the head would have zero inputs).
        """
        super().__init__()

        if frame_num < 2:
            raise ValueError('frame_num must be at least 2, got {!r}'.format(frame_num))

        resnet18_3d = generate_model(18)
        if pretrained:
            # Load onto the CPU first so a checkpoint saved from a GPU also
            # loads on CPU-only machines; load_state_dict copies the values
            # into the model's own parameters afterwards.
            checkpoint = torch.load(resnet_3d_pretrained_model, map_location='cpu')
            # The head is resized to 1039 outputs right before loading,
            # presumably to match the checkpoint's classifier shape.
            resnet18_3d.fc = nn.Linear(512, 1039)
            resnet18_3d.load_state_dict(checkpoint['state_dict'])

        self.resnet18_3d = resnet18_3d
        resnet18_3d_last_dim = 512
        cnn_last_dim = 1024
        # Per-input projection applied to the 512-d backbone feature.
        self.cnn_last = nn.Sequential(
            nn.Linear(resnet18_3d_last_dim, cnn_last_dim),
            nn.BatchNorm1d(cnn_last_dim),
            nn.ReLU(inplace=True),
            nn.Dropout()
        )

        comb_fc1_out = 512
        # Takes 2 * cnn_last_dim inputs -- presumably the concatenated
        # features of one pair (verify against forward()).
        self.comb_fc1 = nn.Sequential(
            nn.Linear(cnn_last_dim * 2, comb_fc1_out),
            nn.BatchNorm1d(comb_fc1_out),
            nn.ReLU(inplace=True),
            nn.Dropout()
        )
        class_num = math.factorial(frame_num) // 2
        # Every unordered pair of indices in range(frame_num).
        self.combination_list = list(itertools.combinations(range(frame_num), 2))
        # One comb_fc1 output per pair feeds the final classifier.
        self.comb_fc2 = nn.Linear(comb_fc1_out * len(self.combination_list), class_num)
        nn.init.kaiming_normal_(self.cnn_last[0].weight)
        nn.init.kaiming_normal_(self.comb_fc1[0].weight)
        nn.init.kaiming_normal_(self.comb_fc2.weight)
Exemplo n.º 2
0
    def __init__(self, pretrained: bool = True, resnet_3d_pretrained_model='r3d18_KM_200ep.pth'):
        """Build the backbone and a three-layer fully connected 2-way head.

        Args:
            pretrained: When true, load backbone weights from
                ``resnet_3d_pretrained_model``.
            resnet_3d_pretrained_model: Path to a checkpoint holding the
                backbone weights under the ``'state_dict'`` key.
        """
        super().__init__()

        resnet18_3d = generate_model(18)
        if pretrained:
            # Load onto the CPU first so a checkpoint saved from a GPU also
            # loads on CPU-only machines.
            checkpoint = torch.load(resnet_3d_pretrained_model, map_location='cpu')
            # The head is resized to 1039 outputs right before loading,
            # presumably to match the checkpoint's classifier shape.
            resnet18_3d.fc = nn.Linear(512, 1039)
            resnet18_3d.load_state_dict(checkpoint['state_dict'])
        self.resnet18_3d = resnet18_3d
        resnet18_3d_last_dim = 512
        hidden_dim = 4096
        # fc1 takes twice the backbone feature width -- presumably two
        # concatenated backbone features (verify against forward()).
        self.fc1 = nn.Linear(resnet18_3d_last_dim * 2, hidden_dim)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(hidden_dim, 2)
        nn.init.kaiming_normal_(self.fc1.weight)
        nn.init.kaiming_normal_(self.fc2.weight)
        nn.init.kaiming_normal_(self.fc3.weight)
Exemplo n.º 3
0
    def __init__(self, class_num: int = 101, bidirectional: bool = True, pretrained: bool = True,
                 resnet_3d_pretrained_model='r3d18_KM_200ep.pth'):
        """Build the backbone, a two-layer GRU and a linear classifier.

        Args:
            class_num: Number of output classes of the final linear layer.
            bidirectional: Whether the GRU runs in both directions.
            pretrained: When true, load backbone weights from
                ``resnet_3d_pretrained_model``.
            resnet_3d_pretrained_model: Path to a checkpoint holding the
                backbone weights under the ``'state_dict'`` key.
        """
        super().__init__()

        resnet18_3d = generate_model(18)
        if pretrained:
            # Load onto the CPU first so a checkpoint saved from a GPU also
            # loads on CPU-only machines.
            checkpoint = torch.load(resnet_3d_pretrained_model, map_location='cpu')
            # The head is resized to 1039 outputs right before loading,
            # presumably to match the checkpoint's classifier shape.
            resnet18_3d.fc = nn.Linear(512, 1039)
            resnet18_3d.load_state_dict(checkpoint['state_dict'])

        self.resnet18_3d = resnet18_3d

        resnet18_3d_last_dim = 512

        rnn_out_dim = 512
        # A bidirectional GRU concatenates both directions, so each direction
        # gets half the width to keep the output at rnn_out_dim either way.
        hidden_size = rnn_out_dim // 2 if bidirectional else rnn_out_dim
        self.rnn = nn.GRU(resnet18_3d_last_dim, hidden_size, bidirectional=bidirectional, num_layers=2)

        self.fc = nn.Linear(rnn_out_dim, class_num)
        nn.init.kaiming_normal_(self.fc.weight)
Exemplo n.º 4
0
    ),
    batch_size=batch_size,
    shuffle=True)
# Evaluation loader over the 'test' path list; order is kept fixed (no shuffling).
test_loader = DataLoader(
    VideoTestDataSet(
        path_list=generate_path_list(args, 'test'),
        frame_num=frame_num,
        frame_interval=0,
    ),
    shuffle=False,
    batch_size=batch_size,
)
# Number of batches per pass over each loader.
train_iterate_len, test_iterate_len = len(train_loader), len(test_loader)

# Initial setup
Net = generate_model(18)
if args.use_pretrained_model:
    # Load onto the CPU first so a checkpoint saved from a GPU also loads on
    # CPU-only machines; load_state_dict copies the values into Net afterwards.
    checkpoint = torch.load('r3d18_KM_200ep.pth', map_location='cpu')
    # The head is resized to 1039 outputs right before loading, presumably to
    # match the checkpoint's classifier shape.
    Net.fc = nn.Linear(512, 1039)
    Net.load_state_dict(checkpoint['state_dict'])
# Replace the head with a fresh layer sized for this task's classes.
Net.fc = nn.Linear(512, args.class_num)
criterion = nn.CrossEntropyLoss()  # define the loss function
optimizer = torch.optim.Adam(Net.parameters(), lr=args.learning_rate)  # define the weight-update method
current_epoch = 0

# Create the log files: unless resetting is disabled, truncate both logs and
# write the CSV header line to each (train log first, then test log).
if not args.no_reset_log_file:
    for csv_log_path in (log_train_path, log_test_path):
        with open(csv_log_path, mode='w') as f:
            f.write('epoch,loss,accuracy,time,learning_rate\n')
Exemplo n.º 5
0
def generate_model(opt):
    """Build the network selected by ``opt.model``.

    Only ``'resnet'`` has a builder wired in here; it is created through
    ``resnet_3d.generate_model`` using the corresponding fields of ``opt``.

    Args:
        opt: Options object. ``opt.model`` selects the architecture; for
            ``'resnet'`` the fields ``model_depth``, ``n_classes``,
            ``n_input_channels``, ``resnet_shortcut``, ``conv1_t_size``,
            ``conv1_t_stride``, ``no_max_pool`` and ``resnet_widen_factor``
            are read.

    Returns:
        The model returned by ``resnet_3d.generate_model``.

    Raises:
        AssertionError: If ``opt.model`` is not one of the recognised names.
        NotImplementedError: If ``opt.model`` is a recognised name other than
            ``'resnet'``.
    """
    known_models = (
        'resnet', 'resnet2p1d', 'preresnet', 'wideresnet', 'resnext',
        'densenet'
    )
    assert opt.model in known_models, 'unknown model: {!r}'.format(opt.model)

    # The other recognised architectures (resnet2p1d, wideresnet, resnext,
    # preresnet, densenet) have no builder here. Fail with a clear message
    # instead of reaching the return with no model bound.
    if opt.model != 'resnet':
        raise NotImplementedError(
            "model {!r} is not available; only 'resnet' is implemented".format(opt.model))

    return resnet_3d.generate_model(
        model_depth=opt.model_depth,  # e.g. 50
        n_classes=opt.n_classes,  # number of classes
        n_input_channels=opt.n_input_channels,  # e.g. 3
        shortcut_type=opt.resnet_shortcut,  # A or B, default B
        conv1_t_size=opt.conv1_t_size,  # default 7
        conv1_t_stride=opt.conv1_t_stride,  # stride, default 1
        no_max_pool=opt.no_max_pool,
        widen_factor=opt.resnet_widen_factor)  # default 1