예제 #1
0
파일: submit.py 프로젝트: xdw1995/Fall_Down
def submit_video_2():
    """Rewrite the camera-2 label CSVs with predictions from the 5-frame model.

    Walks ``/data/xudw/labels_submit_video_camera2/<Subject>/<Activity>/<csv>``.
    Each CSV is read (columns ``Timestamp`` and ``Tag``) and then overwritten
    in place with a ``Timestamp,Tag`` header and one row per original row:

    * rows 0-3 keep their original ``Tag`` (a full 5-frame window is not yet
      available);
    * every later row ``i`` gets the activity id predicted by the model from
      the sorted camera-2 frames ``i-4 .. i`` of the matching trial under
      ``/data/xudw/Fall_Down_data/``.

    Side effects: overwrites every CSV under the label directory, prints
    progress to stdout, and moves the model to CUDA.

    Raises:
        FileNotFoundError: if a frame image cannot be decoded by ``cv2.imread``.
        IndexError: if a trial has fewer frames than its CSV has rows.
    """
    import os
    import csv
    import re
    import cv2
    import pandas as pd
    import numpy as np

    # Model output index -> activity id written to the submission CSV.
    dic = {0: 7, 1: 11, 2: 6, 3: 8, 4: 10}
    # Number of consecutive frames stacked into one model input; rows before
    # index ``window - 1`` keep their original label.
    window = 5

    model = TSN(5, 5, modality='RGB', partial_bn=True, is_shift=True)

    pretrained_dict = torch.load(
        '/data/xudw/temporal-shift-module-master/Fall_down_5_frame_5.13.pth.tar',
        map_location='cpu')
    try:
        model_dict = model.module.state_dict()
    except AttributeError:
        model_dict = model.state_dict()

    # Only copy checkpoint entries whose keys exist in this model, then load
    # the merged dict so unmatched keys keep their freshly initialised values.
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    print("load pretrain model")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model = model.cuda()
    model.eval()

    path = '/data/xudw/labels_submit_video_camera2/'  # comes from the single-frame labels
    frame_root = '/data/xudw/Fall_Down_data/'
    digits = re.compile(r'(\d+)')

    def load_frame(frame_path):
        """Read one camera-2 frame, crop it to the fixed ROI, convert BGR->RGB."""
        frame = cv2.imread(frame_path)
        if frame is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('could not read frame: ' + frame_path)
        frame = frame[73:459, 208:394, :]
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    with torch.no_grad():
        for subject in os.listdir(path):
            subject_id = digits.findall(subject)[0]
            for activity in os.listdir(path + subject):
                activity_id = digits.findall(activity)[0]
                activity_dir = path + subject + '/' + activity

                for csv_name in os.listdir(activity_dir):
                    csv_path = activity_dir + '/' + csv_name
                    csv_data = pd.read_csv(csv_path)
                    trial_id = digits.findall(csv_name)[-1]
                    loc = list(np.array(csv_data["Timestamp"]))
                    label = list(np.array(csv_data['Tag']))
                    print(csv_path)

                    frame_dir = (frame_root + 'Subject' + subject_id +
                                 'Activity' + activity_id + 'Trial' +
                                 trial_id + 'Camera2')
                    # List frames BEFORE truncating the CSV, so a missing
                    # frame directory cannot wipe the existing label file.
                    frames = sorted(os.listdir(frame_dir))

                    # The CSV content is already in memory, so it is safe to
                    # truncate it; ``with`` guarantees it is flushed and closed.
                    with open(csv_path, 'w', encoding='utf-8',
                              newline='') as f:
                        csv_writer = csv.writer(f)
                        csv_writer.writerow(["Timestamp", "Tag"])

                        for i in range(len(loc)):
                            if i < window - 1:
                                # Not enough history for a full window yet:
                                # keep the original label.
                                print(loc[i], label[i])
                                csv_writer.writerow([loc[i], label[i]])
                                continue

                            # Explicit indexing (not slicing) so a short frame
                            # list raises IndexError instead of silently
                            # producing a smaller window.
                            clip = [
                                load_frame(frame_dir + '/' + frames[j])
                                for j in range(i - window + 1, i + 1)
                            ]
                            # Stack the RGB frames along axis 2.
                            buffer = np.concatenate(clip, axis=2)
                            labels = model(
                                test_transform(buffer).unsqueeze(dim=0).cuda())
                            _, predicted = torch.max(labels.data, 1)
                            pred_id = dic[int(predicted)]

                            print(loc[i])
                            print(label[i])
                            print("----" * 10)
                            print(loc[i], pred_id)
                            csv_writer.writerow([loc[i], str(pred_id)])
예제 #2
0
파일: submit.py 프로젝트: xdw1995/Fall_Down
def submit_2():
    """Rewrite the label CSVs with per-frame predictions from the single-frame model.

    Walks ``/data/xudw/labels_submit_2/<Subject>/<Activity>/<csv>``. Each CSV
    is read (columns ``Timestamp`` and ``Tag``) and then overwritten in place
    with a ``Timestamp,Tag`` header and, for every row ``i``, the activity id
    the model predicts from the ``i``-th sorted camera-2 frame of the
    matching trial under ``/data/xudw/Fall_Down_data/``.

    Side effects: overwrites every CSV under the label directory, prints
    progress to stdout, and moves the model to CUDA.

    Raises:
        FileNotFoundError: if a frame image cannot be decoded by ``cv2.imread``.
        IndexError: if a trial has fewer frames than its CSV has rows.
    """
    import os
    import csv
    import re
    import cv2
    import pandas as pd
    import numpy as np

    # Model output index -> activity id written to the submission CSV.
    dic = {0: 7, 1: 11, 2: 6, 3: 8, 4: 10}

    model = TSN(5,
                1,
                base_model='resnet34',
                modality='RGB',
                partial_bn=True,
                is_shift=False)
    pretrained_dict = torch.load(
        '/data/xudw/temporal-shift-module-master/Fall_down_Single_test111_5.13.pth.tar',
        map_location='cpu')
    try:
        model_dict = model.module.state_dict()
    except AttributeError:
        model_dict = model.state_dict()

    # Only copy checkpoint entries whose keys exist in this model, then load
    # the merged dict so unmatched keys keep their freshly initialised values.
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    print("load pretrain model")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model = model.cuda()
    model.eval()

    path = '/data/xudw/labels_submit_2/'
    frame_root = '/data/xudw/Fall_Down_data/'
    digits = re.compile(r'(\d+)')

    with torch.no_grad():
        for subject in os.listdir(path):
            subject_id = digits.findall(subject)[0]
            for activity in os.listdir(path + subject):
                activity_id = digits.findall(activity)[0]
                activity_dir = path + subject + '/' + activity

                for csv_name in os.listdir(activity_dir):
                    csv_path = activity_dir + '/' + csv_name
                    csv_data = pd.read_csv(csv_path)
                    trial_id = digits.findall(csv_name)[-1]
                    loc = list(np.array(csv_data["Timestamp"]))
                    label = list(np.array(csv_data['Tag']))

                    frame_dir = (frame_root + 'Subject' + subject_id +
                                 'Activity' + activity_id + 'Trial' +
                                 trial_id + 'Camera2')
                    # List frames BEFORE truncating the CSV, so a missing
                    # frame directory cannot wipe the existing label file.
                    frames = sorted(os.listdir(frame_dir))
                    print(csv_path)

                    # The CSV content is already in memory, so it is safe to
                    # truncate it; ``with`` guarantees it is flushed and closed.
                    with open(csv_path, 'w', encoding='utf-8',
                              newline='') as f:
                        csv_writer = csv.writer(f)
                        csv_writer.writerow(["Timestamp", "Tag"])

                        for i, timestamp in enumerate(loc):
                            frame_path = frame_dir + '/' + frames[i]
                            frame = cv2.imread(frame_path)
                            if frame is None:
                                # cv2.imread reports failure by returning None.
                                raise FileNotFoundError(
                                    'could not read frame: ' + frame_path)
                            # NOTE(review): unlike submit_video_2, no
                            # [73:459, 208:394] crop is applied here -- confirm
                            # this matches how the single-frame model was trained.
                            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                            labels = model(
                                test_transform(frame).unsqueeze(dim=0).cuda())
                            _, predicted = torch.max(labels.data, 1)
                            pred_id = dic[int(predicted)]

                            print(label[i])
                            print(pred_id)
                            csv_writer.writerow([timestamp, str(pred_id)])