# Evaluation loader. Clip paths come from ucf101_test_path_load; shuffle=False
# keeps the sample order fixed across evaluation passes.
test_loader = DataLoader(
    VideoSortingPerFrameTrainDataSet(
        frame_num=frame_num,
        path_list=ucf101_test_path_load(
            args.dataset_path,
            args.test_label_path,
            args.class_path,
        ),
        interval_frame=frame_interval,
        sorting_start_index=sorting_start_index,
        sorting_end_index=sorting_end_index,
    ),
    batch_size=batch_size,
    shuffle=False,
)
# Number of mini-batches per epoch for the train and test splits.
train_iterate_len, test_iterate_len = len(train_loader), len(test_loader)

# ---- Initial setup ----
# Model. The original note says this "gets resnet18"; the backbone is defined
# inside CNN_RNN elsewhere in the project (TODO confirm).
Net = CNN_RNN(
    frame_num,
    pretrained=args.use_pretrained_model,
    bidirectional=args.use_bidirectional,
)
# Loss function.
criterion = torch.nn.CrossEntropyLoss()
# Weight-update rule: Adam over every model parameter.
optimizer = torch.optim.Adam(Net.parameters(), lr=args.learning_rate)
current_epoch = 0

# ログファイルの生成
if not args.no_reset_log_file:
    with open(log_train_path, mode='w') as f:
        f.write(
            'epoch,loss,full_fit_accuracy,per_fit_accuracy,sorting_accuracy,time,learning_rate\n'
        )
    with open(log_test_path, mode='w') as f:
        f.write(
            'epoch,loss,full_fit_accuracy,per_fit_accuracy,sorting_accuracy,time,learning_rate\n'
示例#2
0
        frame_num=frame_num,
        path_list=ucf101_train_path_load(args.dataset_path, args.train_label_path),
        frame_interval=interval_frames),
    batch_size=batch_size, shuffle=True)
test_loader = DataLoader(
    VideoSortTrainDataSet(
        frame_num=frame_num,
        path_list=ucf101_test_path_load(args.dataset_path, args.test_label_path, args.class_path),
        frame_interval=interval_frames),
    batch_size=batch_size, shuffle=False)
train_iterate_len = len(train_loader)
test_iterate_len = len(test_loader)

# ---- Initial setup ----
# Model: CNN_RNN in classification mode. The original note says this "gets
# resnet18"; the backbone is defined inside CNN_RNN elsewhere (TODO confirm).
# Build it with the same `frame_num` the train/test datasets above use, so the
# model's frame count cannot drift from the frames each sample provides.
# The previous code read `args.frame_num` directly instead.
Net = CNN_RNN(frame_num,
              pretrained=args.use_pretrained_model,
              bidirectional=args.use_bidirectional,
              task='classification')
criterion = torch.nn.CrossEntropyLoss()  # loss function
optimizer = torch.optim.Adam(Net.parameters(), lr=args.learning_rate)  # weight-update rule
current_epoch = 0

# Create (truncate) the train and test log files and write the CSV header
# row into each, unless the user asked to keep the existing logs.
if not args.no_reset_log_file:
    for _log_path in (log_train_path, log_test_path):
        with open(_log_path, mode='w') as f:
            f.write('epoch,loss,full_fit_accuracy,per_fit_accuracy,time,learning_rate\n')

# Switch to GPU execution when CUDA use is requested.
if args.use_cuda:
    criterion = criterion.cuda()
    Net = Net.cuda()
    # DataParallel splits each input batch across the visible GPUs.
    Net = torch.nn.DataParallel(Net)