def main():
    global args, best_prec1, model, writer, best_loss, length, width, height, input_size, scheduler
    args = parser.parse_args()
    training_continue = args.contine  # note: 'contine' matches the attribute name exposed by the argument parser
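    # The spatial scale is inferred from the architecture name: 112-pixel variants use 0.5, 224-pixel variants use 1.0.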
    if '3D' in args.arch:
        if 'I3D' in args.arch or 'MFNET3D' in args.arch:
            if '112' in args.arch:
                scale = 0.5
            else:
                scale = 1
        else:
            if '224' in args.arch:
                scale = 1
            else:
                scale = 0.5
    elif 'r2plus1d' in args.arch:
        scale = 0.5
    else:
        scale = 1

    print('scale: %.1f' % (scale))

    input_size = int(224 * scale)
    width = int(340 * scale)
    height = int(256 * scale)

    saveLocation = "./checkpoint/" + args.dataset + "_" + args.arch + "_split" + str(
        args.split)
    if not os.path.exists(saveLocation):
        os.makedirs(saveLocation)
    writer = SummaryWriter(saveLocation)

    # create model

    if args.evaluate:
        print("Building validation model ... ")
        model = build_model_validate()
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
    elif training_continue:
        model, startEpoch, optimizer, best_prec1 = build_model_continue()
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
            #param_group['lr'] = lr
        print(
            "Continuing training: best precision %.3f, start epoch %d, lr %f"
            % (best_prec1, startEpoch, lr))
    else:
        print("Building model with ADAMW... ")
        model = build_model()
        optimizer = AdamW(model.parameters(),
                          lr=args.lr,
                          weight_decay=args.weight_decay)
        startEpoch = 0

    if HALF:
        model.half()  # convert to half precision
        for layer in model.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    print("Model %s is loaded. " % (args.arch))

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    criterion2 = nn.MSELoss().cuda()

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=5,
                                               verbose=True)

    print("Saving everything to directory %s." % (saveLocation))
    if args.dataset == 'ucf101':
        dataset = './datasets/ucf101_frames'
    elif args.dataset == 'hmdb51':
        dataset = './datasets/hmdb51_frames'
    elif args.dataset == 'smtV2':
        dataset = './datasets/smtV2_frames'
    elif args.dataset == 'window':
        dataset = './datasets/window_frames'
    elif args.dataset == 'haa500_basketball':
        dataset = './datasets/haa500_basketball_frames'
    else:
        print("Unknown dataset entered, exiting....")
        return 0

    cudnn.benchmark = True
    modality = args.arch.split('_')[0]
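    # The temporal clip length is also inferred from the architecture name: 64f/32f suffixes, otherwise 16 frames for 3D models and a single frame for 2D models.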
    if "3D" in args.arch or 'tsm' in args.arch or 'slowfast' in args.arch or 'r2plus1d' in args.arch:
        if '64f' in args.arch:
            length = 64
        elif '32f' in args.arch:
            length = 32
        else:
            length = 16
    else:
        length = 1
    # Data transforming
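    # The normalization statistics below are the per-channel mean/std of each backbone's pretraining data, repeated for every frame in the clip.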
    if modality == "rgb" or modality == "pose":
        is_color = True
        scale_ratios = [1.0, 0.875, 0.75, 0.66]
        if 'I3D' in args.arch:
            if 'resnet' in args.arch:
                clip_mean = [0.45, 0.45, 0.45] * args.num_seg * length
                clip_std = [0.225, 0.225, 0.225] * args.num_seg * length
            else:
                clip_mean = [0.5, 0.5, 0.5] * args.num_seg * length
                clip_std = [0.5, 0.5, 0.5] * args.num_seg * length
            #clip_std = [0.25, 0.25, 0.25] * args.num_seg * length
        elif 'MFNET3D' in args.arch:
            clip_mean = [0.48627451, 0.45882353, 0.40784314
                         ] * args.num_seg * length
            clip_std = [0.234, 0.234, 0.234] * args.num_seg * length
        elif "3D" in args.arch:
            clip_mean = [114.7748, 107.7354, 99.4750] * args.num_seg * length
            clip_std = [1, 1, 1] * args.num_seg * length
        elif "r2plus1d" in args.arch:
            clip_mean = [0.43216, 0.394666, 0.37645] * args.num_seg * length
            clip_std = [0.22803, 0.22145, 0.216989] * args.num_seg * length
        elif "rep_flow" in args.arch:
            clip_mean = [0.5, 0.5, 0.5] * args.num_seg * length
            clip_std = [0.5, 0.5, 0.5] * args.num_seg * length
        elif "slowfast" in args.arch:
            clip_mean = [0.45, 0.45, 0.45] * args.num_seg * length
            clip_std = [0.225, 0.225, 0.225] * args.num_seg * length
        else:
            clip_mean = [0.485, 0.456, 0.406] * args.num_seg * length
            clip_std = [0.229, 0.224, 0.225] * args.num_seg * length
    # Note: this branch is unreachable because "pose" is already handled together with "rgb" above.
    elif modality == "pose":
        is_color = True
        scale_ratios = [1.0, 0.875, 0.75, 0.66]
        clip_mean = [0.485, 0.456, 0.406] * args.num_seg
        clip_std = [0.229, 0.224, 0.225] * args.num_seg
    elif modality == "flow":
        is_color = False
        scale_ratios = [1.0, 0.875, 0.75, 0.66]
        if 'I3D' in args.arch:
            clip_mean = [0.5, 0.5] * args.num_seg * length
            clip_std = [0.5, 0.5] * args.num_seg * length
        elif "3D" in args.arch:
            clip_mean = [127.5, 127.5] * args.num_seg * length
            clip_std = [1, 1] * args.num_seg * length
        else:
            clip_mean = [0.5, 0.5] * args.num_seg * length
            clip_std = [0.226, 0.226] * args.num_seg * length
    elif modality == "both":
        is_color = True
        scale_ratios = [1.0, 0.875, 0.75, 0.66]
        clip_mean = [0.485, 0.456, 0.406, 0.5, 0.5] * args.num_seg * length
        clip_std = [0.229, 0.224, 0.225, 0.226, 0.226] * args.num_seg * length
    else:
        print("Unknown modality. Only rgb, pose, flow and both are supported, exiting....")
        return 0

    normalize = video_transforms.Normalize(mean=clip_mean, std=clip_std)

    if "3D" in args.arch and not ('I3D' in args.arch):
        train_transform = video_transforms.Compose([
            video_transforms.MultiScaleCrop((input_size, input_size),
                                            scale_ratios),
            video_transforms.RandomHorizontalFlip(),
            video_transforms.ToTensor2(),
            normalize,
        ])

        val_transform = video_transforms.Compose([
            video_transforms.CenterCrop((input_size)),
            video_transforms.ToTensor2(),
            normalize,
        ])
    else:
        train_transform = video_transforms.Compose([
            video_transforms.MultiScaleCrop((input_size, input_size),
                                            scale_ratios),
            video_transforms.RandomHorizontalFlip(),
            video_transforms.ToTensor(),
            normalize,
        ])

        val_transform = video_transforms.Compose([
            video_transforms.CenterCrop((input_size)),
            video_transforms.ToTensor(),
            normalize,
        ])

    # data loading
    train_setting_file = "train_%s_split%d.txt" % (modality, args.split)
    train_split_file = os.path.join(args.settings, args.dataset,
                                    train_setting_file)
    val_setting_file = "val_%s_split%d.txt" % (modality, args.split)
    val_split_file = os.path.join(args.settings, args.dataset,
                                  val_setting_file)
    if not os.path.exists(train_split_file) or not os.path.exists(
            val_split_file):
        print(
            "No split file exists in %s directory. Preprocess the dataset first"
            % (args.settings))
        return 0

    train_dataset = datasets.__dict__[args.dataset](
        root=dataset,
        source=train_split_file,
        phase="train",
        modality=modality,
        is_color=is_color,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=train_transform,
        num_segments=args.num_seg)

    val_dataset = datasets.__dict__[args.dataset](
        root=dataset,
        source=val_split_file,
        phase="val",
        modality=modality,
        is_color=is_color,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=val_transform,
        num_segments=args.num_seg)

    print('{} samples found, {} train samples and {} test samples.'.format(
        len(val_dataset) + len(train_dataset), len(train_dataset),
        len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        prec1, prec3, _ = validate(val_loader, model, criterion, criterion2,
                                   modality)
        return

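    # Main training loop: train every epoch; validate, step the LR scheduler, and checkpoint every args.save_freq epochs.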
    for epoch in range(startEpoch, args.epochs):
        #        if learning_rate_index > max_learning_rate_decay_count:
        #            break
        #        adjust_learning_rate(optimizer, epoch)
        train(train_loader, model, criterion, criterion2, optimizer, epoch,
              modality)

        # evaluate on validation set
        prec1 = 0.0
        lossClassification = 0
        if (epoch + 1) % args.save_freq == 0:
            prec1, prec3, lossClassification = validate(
                val_loader, model, criterion, criterion2, modality)
            writer.add_scalar('data/top1_validation', prec1, epoch)
            writer.add_scalar('data/top3_validation', prec3, epoch)
            writer.add_scalar('data/classification_loss_validation',
                              lossClassification, epoch)
            scheduler.step(lossClassification)
        # remember best prec@1 and save checkpoint

        is_best = prec1 >= best_prec1
        best_prec1 = max(prec1, best_prec1)
        #        best_in_existing_learning_rate = max(prec1, best_in_existing_learning_rate)
        #
        #        if best_in_existing_learning_rate > prec1 + 1:
        #            learning_rate_index = learning_rate_index
        #            best_in_existing_learning_rate = 0

        if (epoch + 1) % args.save_freq == 0:
            checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
            if is_best:
                print("Model works well")
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'best_loss': best_loss,
                        'optimizer': optimizer.state_dict(),
                    }, is_best, checkpoint_name, saveLocation)

    checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint_name, saveLocation)
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Example #2
def main(args):
    global best_prec1, best_loss

    input_size = int(224 * args.scale)
    width = int(340 * args.scale)
    height = int(256 * args.scale)

    if not os.path.exists(args.savelocation):
        os.makedirs(args.savelocation)

    now = time.time()
    savelocation = os.path.join(args.savelocation, str(now))
    os.makedirs(savelocation)

    logging.basicConfig(filename=os.path.join(savelocation, "log.log"),
                        level=logging.INFO)

    model = build_model(args.arch, args.pre, args.num_seg, args.resume)
    optimizer = AdamW(model.parameters(),
                      lr=args.lr,
                      weight_decay=args.weight_decay)

    criterion = nn.CrossEntropyLoss().cuda()
    criterion2 = nn.MSELoss().cuda()

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'min',
                                               patience=5,
                                               verbose=True)

    # if args.dataset=='sign':
    #     dataset="/data/AUTSL/train_img_c"
    # elif args.dataset=="signd":
    #     dataset="/data/AUTSL/train_img_c"
    # elif args.dataset=="customd":
    #     dataset="/data/AUTSL/train_img_c"
    # else:
    #     print("no dataset")
    #     return 0

    cudnn.benchmark = True
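    # This example hard-codes 64-frame clips and the R(2+1)D (Kinetics) normalization statistics.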
    length = 64

    scale_ratios = [1.0, 0.875, 0.75, 0.66]

    clip_mean = [0.43216, 0.394666, 0.37645] * args.num_seg * length
    clip_std = [0.22803, 0.22145, 0.216989] * args.num_seg * length

    normalize = video_transforms.Normalize(mean=clip_mean, std=clip_std)

    train_transform = video_transforms.Compose([
        video_transforms.CenterCrop(input_size),
        video_transforms.ToTensor2(),
        normalize,
    ])

    val_transform = video_transforms.Compose([
        video_transforms.CenterCrop((input_size)),
        video_transforms.ToTensor2(),
        normalize,
    ])

    # test_transform = video_transforms.Compose([
    #     video_transforms.CenterCrop((input_size)),
    #     video_transforms.ToTensor2(),
    #     normalize,
    # ])
    # test_file = os.path.join(args.datasetpath, args.testlist)

    if not os.path.exists(args.trainlist) or not os.path.exists(args.vallist):
        print(
            "No split file exists in %s directory. Preprocess the dataset first"
            % (args.datasetpath))
        return 0

    train_dataset = datasets.__dict__[args.dataset](
        root=args.datasetpath,
        source=args.trainlist,
        phase="train",
        modality="rgb",
        is_color=True,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=train_transform,
        num_segments=args.num_seg)

    val_dataset = datasets.__dict__[args.dataset](
        root=args.datasetpath,
        source=args.vallist,
        phase="val",
        modality="rgb",
        is_color=True,
        new_length=length,
        new_width=width,
        new_height=height,
        video_transform=val_transform,
        num_segments=args.num_seg)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    best_prec1 = 0
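    # Train every epoch; validate and checkpoint every args.save_freq epochs, tracking the best top-1 precision.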
    for epoch in range(0, args.epochs):
        train(length, input_size, train_loader, model, criterion, criterion2,
              optimizer, epoch)

        if (epoch + 1) % args.save_freq == 0:
            is_best = False
            prec1, prec3, lossClassification = validate(
                length, input_size, val_loader, model, criterion, criterion2)
            scheduler.step(lossClassification)

            if prec1 >= best_prec1:
                is_best = True
                best_prec1 = prec1

            checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
            text = "save checkpoint {}".format(checkpoint_name)
            print(text)
            logging.info(text)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "prec1": prec1,
                    "optimizer": optimizer.state_dict()
                }, is_best, checkpoint_name, savelocation)
Example #3
def VideoSpatialPrediction3D(vid_name,
                             net,
                             num_categories,
                             architecture_name,
                             start_frame=0,
                             num_frames=0,
                             length=16,
                             extension='img_{0:05d}.jpg',
                             ten_crop=False):

    if num_frames == 0:
        imglist = os.listdir(vid_name)
        newImageList = []
        if 'rgb' in architecture_name or 'pose' in architecture_name:
            for item in imglist:
                if 'img' in item:
                    newImageList.append(item)
        elif 'flow' in architecture_name:
            for item in imglist:
                if 'flow_x' in item:
                    newImageList.append(item)
        duration = len(newImageList)
    else:
        duration = num_frames

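    # Normalization statistics, tensor conversion and spatial scale are chosen according to the backbone named in architecture_name.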
    if 'rgb' in architecture_name or 'pose' in architecture_name:
        if 'I3D' in architecture_name:

            if not 'resnet' in architecture_name:
                clip_mean = [0.5, 0.5, 0.5]
                clip_std = [0.5, 0.5, 0.5]
            else:
                clip_mean = [0.45, 0.45, 0.45]
                clip_std = [0.225, 0.225, 0.225]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose([
                video_transforms.ToTensor(),
                normalize,
            ])
            if '112' in architecture_name:
                scale = 0.5
            else:
                scale = 1
        elif 'MFNET3D' in architecture_name:
            clip_mean = [0.48627451, 0.45882353, 0.40784314]
            clip_std = [0.234, 0.234, 0.234]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose(
                [video_transforms.ToTensor(), normalize])
            if '112' in architecture_name:
                scale = 0.5
            else:
                scale = 1
        elif 'tsm' in architecture_name:
            clip_mean = [0.485, 0.456, 0.406]
            clip_std = [0.229, 0.224, 0.225]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose(
                [video_transforms.ToTensor(), normalize])
            scale = 1
        elif "r2plus1d" in architecture_name:
            clip_mean = [0.43216, 0.394666, 0.37645]
            clip_std = [0.22803, 0.22145, 0.216989]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose(
                [video_transforms.ToTensor(), normalize])
            scale = 0.5
        elif 'rep_flow' in architecture_name:
            clip_mean = [0.5, 0.5, 0.5]
            clip_std = [0.5, 0.5, 0.5]

            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose([
                video_transforms.ToTensor(),
                normalize,
            ])
            scale = 1
        elif "slowfast" in architecture_name:
            clip_mean = [0.45, 0.45, 0.45]
            clip_std = [0.225, 0.225, 0.225]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose([
                video_transforms.ToTensor(),
                normalize,
            ])
            scale = 1
        else:
            scale = 0.5
            clip_mean = [114.7748, 107.7354, 99.4750]
            clip_std = [1, 1, 1]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose([
                video_transforms.ToTensor2(),
                normalize,
            ])
    elif 'flow' in architecture_name:
        if 'I3D' in architecture_name:
            clip_mean = [0.5] * 2
            clip_std = [0.5] * 2
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)

            val_transform = video_transforms.Compose([
                video_transforms.ToTensor(),
                normalize,
            ])
            scale = 1
        elif "3D" in architecture_name:
            scale = 0.5
            clip_mean = [127.5, 127.5]
            clip_std = [1, 1]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)
            val_transform = video_transforms.Compose([
                video_transforms.ToTensor2(),
                normalize,
            ])
        elif "r2plus1d" in architecture_name:
            clip_mean = [0.5] * 2
            clip_std = [0.226] * 2
            normalize = video_transforms.Normalize(mean=clip_mean,
                                                   std=clip_std)

            val_transform = video_transforms.Compose([
                video_transforms.ToTensor(),
                normalize,
            ])
            scale = 0.5

    if '224' in architecture_name:
        scale = 1
    if '112' in architecture_name:
        scale = 0.5
    # selection
    #step = int(math.floor((duration-1)/(num_samples-1)))
    dims2 = (224, 224, 3, duration)

    imageSize = int(224 * scale)
    dims = (int(256 * scale), int(340 * scale), 3, duration)
    #dims = (int(256 * scale),int(256 * scale),3,duration)
    duration = duration - 1

    offsets = []

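    # Sample non-overlapping windows of `length` frames across the video; if the video is shorter than one window, cycle through its frames to fill the window.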
    offsetMainIndexes = list(range(1, duration - length, length))
    if len(offsetMainIndexes) == 0:
        offsets = list(range(1, duration + 2)) * int(
            np.floor(length / (duration + 1))) + list(
                range(1, length % (duration + 1) + 1))
    else:
        shift = int((duration - (offsetMainIndexes[-1] + length)) / 2)
        for mainOffsetValue in offsetMainIndexes:
            for lengthID in range(1, length + 1):
                offsets.append(lengthID + mainOffsetValue + shift)

#    offsetMainIndexes = list(range(0,duration,length))
#    for mainOffsetValue in offsetMainIndexes:
#        for lengthID in range(1, length+1):
#            loaded_frame_index = lengthID + mainOffsetValue
#            moded_loaded_frame_index = loaded_frame_index % (duration + 1)
#            if moded_loaded_frame_index == 0:
#                moded_loaded_frame_index = (duration + 1)
#            offsets.append(moded_loaded_frame_index)

    imageList = []
    imageList1 = []
    imageList2 = []
    imageList3 = []
    imageList4 = []
    imageList5 = []
    imageList6 = []
    imageList7 = []
    imageList8 = []
    imageList9 = []
    imageList10 = []
    imageList11 = []
    imageList12 = []
    interpolation = cv2.INTER_LINEAR

    for index in offsets:
        if 'rgb' in architecture_name or 'pose' in architecture_name:
            img_file = os.path.join(vid_name, extension.format(index))
            img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)

            img = cv2.resize(img, dims[1::-1], interpolation)

            #img2 = cv2.resize(img, dims2[1::-1],interpolation)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_flip = img[:, ::-1, :].copy()
        elif 'flow' in architecture_name:
            flow_x_file = os.path.join(vid_name, extension.format('x', index))
            flow_y_file = os.path.join(vid_name, extension.format('y', index))
            img_x = cv2.imread(flow_x_file, cv2.IMREAD_GRAYSCALE)
            img_y = cv2.imread(flow_y_file, cv2.IMREAD_GRAYSCALE)
            img_x = np.expand_dims(img_x, -1)
            img_y = np.expand_dims(img_y, -1)
            img = np.concatenate((img_x, img_y), 2)
            img = cv2.resize(img, dims[1::-1], interpolation)
            img_flip = img[:, ::-1, :].copy()
        #img_flip2 = img2[:,::-1,:].copy()
        #imageList1.append(img[int(16 * scale):int(16 * scale + imageSize), int(16 * scale) : int(16 * scale + imageSize), :])
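        # Five spatial crops (near-center and four corners) of the frame and of its horizontal flip; together they form the ten-crop set.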
        imageList1.append(img[int(16 * scale):int(16 * scale + imageSize),
                              int(58 * scale):int(58 * scale + imageSize), :])
        imageList2.append(img[:imageSize, :imageSize, :])
        imageList3.append(img[:imageSize, -imageSize:, :])
        imageList4.append(img[-imageSize:, :imageSize, :])
        imageList5.append(img[-imageSize:, -imageSize:, :])
        imageList6.append(img_flip[int(16 * scale):int(16 * scale + imageSize),
                                   int(58 * scale):int(58 * scale +
                                                       imageSize), :])
        imageList7.append(img_flip[:imageSize, :imageSize, :])
        imageList8.append(img_flip[:imageSize, -imageSize:, :])
        imageList9.append(img_flip[-imageSize:, :imageSize, :])
        imageList10.append(img_flip[-imageSize:, -imageSize:, :])
#        imageList11.append(img2)
#        imageList12.append(img_flip2)

    if ten_crop:
        imageList = imageList1 + imageList2 + imageList3 + imageList4 + imageList5 + imageList6 + imageList7 + imageList8 + imageList9 + imageList10
    else:
        imageList = imageList1

    #imageList=imageList11+imageList12

    rgb_list = []

    for i in range(len(imageList)):
        cur_img = imageList[i]
        cur_img_tensor = val_transform(cur_img)
        rgb_list.append(np.expand_dims(cur_img_tensor.numpy(), 0))

    input_data = np.concatenate(rgb_list, axis=0)
    if 'rgb' in architecture_name or 'pose' in architecture_name:
        input_data = input_data.reshape(-1, length, 3, imageSize, imageSize)
    elif 'flow' in architecture_name:
        input_data = input_data.reshape(-1, length, 2, imageSize, imageSize)

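    # Run the crops through the network in batches of 10 clips and average class scores over all crops/windows.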
    batch_size = 10
    result = np.zeros([input_data.shape[0], num_categories])
    num_batches = int(math.ceil(float(input_data.shape[0]) / batch_size))

    with torch.no_grad():
        for bb in range(num_batches):
            span = range(batch_size * bb,
                         min(input_data.shape[0], batch_size * (bb + 1)))
            input_data_batched = input_data[span, :, :, :, :]
            imgDataTensor = torch.from_numpy(input_data_batched).type(
                torch.FloatTensor).cuda()
            if 'rgb' in architecture_name or 'pose' in architecture_name:
                imgDataTensor = imgDataTensor.view(-1, length, 3, imageSize,
                                                   imageSize).transpose(1, 2)
            elif 'flow' in architecture_name:
                imgDataTensor = imgDataTensor.view(-1, length, 2, imageSize,
                                                   imageSize).transpose(1, 2)

            if 'bert' in architecture_name or 'pooling' in architecture_name or 'NLB' in architecture_name \
                or 'lstm' in architecture_name or 'adamw' in architecture_name:
                output, input_vectors, sequenceOut, maskSample = net(
                    imgDataTensor)
            else:
                output = net(imgDataTensor)
            #span = range(sample_size*bb, min(int(input_data.shape[0]/length),sample_size*(bb+1)))
            result[span, :] = output.data.cpu().numpy()
        mean_result = np.mean(result, 0)
        prediction = np.argmax(mean_result)
        top3 = mean_result.argsort()[::-1][:3]
        top5 = mean_result.argsort()[::-1][:5]

    return prediction, mean_result, top3
def VideoSpatialPrediction3D_bert(vid_name,
                                  net,
                                  num_categories,
                                  architecture_name,
                                  start_frame=0,
                                  num_frames=0,
                                  num_seg=4,
                                  length=16,
                                  extension='img_{0:05d}.jpg',
                                  ten_crop=False):

    if num_frames == 0:
        imglist = os.listdir(vid_name)
        newImageList = []
        if 'rgb' in architecture_name or 'pose' in architecture_name:
            for item in imglist:
                if 'img' in item:
                    newImageList.append(item)
        elif 'flow' in architecture_name:
            for item in imglist:
                if 'flow_x' in item:
                    newImageList.append(item)
        duration = len(newImageList)
    else:
        duration = num_frames
    
    if 'rgb' in architecture_name or 'pose' in architecture_name:
        if 'I3D' in architecture_name:
            
            if not 'resnet' in architecture_name:
                clip_mean = [0.5, 0.5, 0.5] 
                clip_std = [0.5, 0.5, 0.5]
            else:
                clip_mean = [0.45, 0.45, 0.45]
                clip_std = [0.225, 0.225, 0.225] 
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor(),
                    normalize,
                ])
            if '112' in architecture_name:
                scale = 0.5
            else:
                scale = 1
        elif 'MFNET3D' in architecture_name:
            clip_mean = [0.48627451, 0.45882353, 0.40784314]
            clip_std = [0.234, 0.234, 0.234] 
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor(),
                    normalize])
            if '112' in architecture_name:
                scale = 0.5
            else:
                scale = 1
        elif 'tsm' in architecture_name:
            clip_mean = [0.485, 0.456, 0.406]
            clip_std = [0.229, 0.224, 0.225] 
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor(),
                    normalize])
            scale = 1
        elif "r2plus1d" in architecture_name:
            clip_mean = [0.43216, 0.394666, 0.37645]
            clip_std = [0.22803, 0.22145, 0.216989]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor(),
                    normalize])
            scale = 0.5
        elif 'rep_flow' in architecture_name:
            clip_mean = [0.5, 0.5, 0.5] 
            clip_std = [0.5, 0.5, 0.5]
    
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor(),
                    normalize,
                ])
            scale = 1
        elif "slowfast" in architecture_name:
            clip_mean = [0.45, 0.45, 0.45]
            clip_std = [0.225, 0.225, 0.225] 
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor(),
                    normalize,       
                ])
            scale = 1
        else:
            scale = 0.5
            clip_mean = [114.7748, 107.7354, 99.4750]
            clip_std = [1, 1, 1]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor2(),
                    normalize,
                ])
    elif 'flow' in architecture_name:
        if 'I3D' in architecture_name:
            clip_mean = [0.5] * 2
            clip_std = [0.5] * 2
            normalize = video_transforms.Normalize(mean=clip_mean,
                                             std=clip_std)
            
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor(),
                    normalize,
                ])
            scale = 1
        else:
            scale = 0.5
            clip_mean = [127.5, 127.5]
            clip_std = [1, 1]
            normalize = video_transforms.Normalize(mean=clip_mean,
                                     std=clip_std)
            val_transform = video_transforms.Compose([
                    video_transforms.ToTensor2(),
                    normalize,
                ])

    # selection
    #step = int(math.floor((duration-1)/(num_samples-1)))
    if '224' in architecture_name:
        scale = 1
    if '112' in architecture_name:
        scale = 0.5

    imageSize = int(224 * scale)
    dims = (int(256 * scale), int(340 * scale), 3, duration)
    duration = duration - 1
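    # TSN-style sampling: split the video into num_seg segments and read a window of `length` consecutive frames from each, wrapping indices that run past the end.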
    average_duration = int(duration / num_seg)
    offsetMainIndexes = []
    offsets = []
    for seg_id in range(num_seg):
        if average_duration >= length:
            offsetMainIndexes.append(
                int((average_duration - length + 1) / 2 +
                    seg_id * average_duration))
        elif duration >= length:
            average_part_length = int(np.floor((duration - length) / num_seg))
            offsetMainIndexes.append(
                int((average_part_length * seg_id +
                     average_part_length * (seg_id + 1)) / 2))
        else:
            increase = int(duration / num_seg)
            offsetMainIndexes.append(0 + seg_id * increase)
    for mainOffsetValue in offsetMainIndexes:
        for lengthID in range(1, length+1):
            loaded_frame_index = lengthID + mainOffsetValue
            moded_loaded_frame_index = loaded_frame_index % (duration + 1)
            if moded_loaded_frame_index == 0:
                moded_loaded_frame_index = (duration + 1)
            offsets.append(moded_loaded_frame_index)
             
    imageList = []
    imageList1 = []
    imageList2 = []
    imageList3 = []
    imageList4 = []
    imageList5 = []
    imageList6 = []
    imageList7 = []
    imageList8 = []
    imageList9 = []
    imageList10 = []
    imageList11 = []
    imageList12 = []
    interpolation = cv2.INTER_LINEAR
    
    for index in offsets:
        if 'rgb' in architecture_name or 'pose' in architecture_name:
            img_file = os.path.join(vid_name, extension.format(index))
            img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED)
    
            img = cv2.resize(img, dims[1::-1],interpolation)
    
            #img2 = cv2.resize(img, dims2[1::-1],interpolation)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_flip = img[:,::-1,:].copy()
        elif 'flow' in architecture_name:
            flow_x_file = os.path.join(vid_name, extension.format('x',index))
            flow_y_file = os.path.join(vid_name, extension.format('y',index))
            img_x = cv2.imread(flow_x_file, cv2.IMREAD_GRAYSCALE)
            img_y = cv2.imread(flow_y_file, cv2.IMREAD_GRAYSCALE)
            img_x = np.expand_dims(img_x,-1)
            img_y = np.expand_dims(img_y,-1)
            img = np.concatenate((img_x,img_y),2)    
            img = cv2.resize(img, dims[1::-1],interpolation)
            img_flip = img[:,::-1,:].copy()
        #img_flip2 = img2[:,::-1,:].copy()
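        # Same ten-crop construction as in VideoSpatialPrediction3D: five crops of the frame plus five of its horizontal flip.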
        imageList1.append(img[int(16 * scale):int(16 * scale + imageSize), int(58 * scale) : int(58 * scale + imageSize), :])
        imageList2.append(img[:imageSize, :imageSize, :])
        imageList3.append(img[:imageSize, -imageSize:, :])
        imageList4.append(img[-imageSize:, :imageSize, :])
        imageList5.append(img[-imageSize:, -imageSize:, :])
        imageList6.append(img_flip[int(16 * scale):int(16 * scale + imageSize), int(58 * scale) : int(58 * scale + imageSize), :])
        imageList7.append(img_flip[:imageSize, :imageSize, :])
        imageList8.append(img_flip[:imageSize, -imageSize:, :])
        imageList9.append(img_flip[-imageSize:, :imageSize, :])
        imageList10.append(img_flip[-imageSize:, -imageSize:, :])
#        imageList11.append(img2)
#        imageList12.append(img_flip2)

    if ten_crop:
        imageList = imageList1 + imageList2 + imageList3 + imageList4 + imageList5 + imageList6 + imageList7 + imageList8 + imageList9 + imageList10
    else:
        imageList = imageList1
    
    #imageList=imageList11+imageList12
    
    rgb_list = []

    for i in range(len(imageList)):
        cur_img = imageList[i]
        cur_img_tensor = val_transform(cur_img)
        rgb_list.append(np.expand_dims(cur_img_tensor.numpy(), 0))

    input_data = np.concatenate(rgb_list, axis=0)

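    # All sampled clips fit into a single forward pass here; class scores are averaged over crops/segments.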
    with torch.no_grad():
        imgDataTensor = torch.from_numpy(input_data).type(
            torch.FloatTensor).cuda()
        if 'rgb' in architecture_name or 'pose' in architecture_name:
            imgDataTensor = imgDataTensor.view(-1, length, 3, imageSize,
                                               imageSize).transpose(1, 2)
        elif 'flow' in architecture_name:
            imgDataTensor = imgDataTensor.view(-1, length, 2, imageSize,
                                               imageSize).transpose(1, 2)
            
        if 'bert' in architecture_name or 'pooling' in architecture_name:
            output, input_vectors, sequenceOut, maskSample = net(imgDataTensor)
        else:
            output = net(imgDataTensor)
#        outputSoftmax=soft(output)
        result = output.data.cpu().numpy()
        mean_result = np.mean(result, 0)
        prediction = np.argmax(mean_result)
        top3 = mean_result.argsort()[::-1][:3]
        
    return prediction, mean_result, top3
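
A minimal usage sketch (not from the original examples): how VideoSpatialPrediction3D_bert might be invoked once a trained model is available. The frame directory, class count, architecture string, and the build_model_validate helper are assumptions and will differ per repository.

# Assumed helper (see Example #1) that builds the network and loads trained weights.
net = build_model_validate()
net.eval()

# Hypothetical frame directory and architecture string; adjust to the dataset and checkpoint at hand.
prediction, mean_scores, top3 = VideoSpatialPrediction3D_bert(
    "./datasets/hmdb51_frames/some_video",
    net,
    num_categories=51,
    architecture_name="rgb_r2plus1d_32f_34_bert10",
    num_seg=1,
    length=32,
    ten_crop=False)
print("predicted class:", prediction, "top-3:", top3)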