Example #1
def main(path_to_vid,output_path):

    assert torch.cuda.is_available(), \
        "CUDA is not available on your machine"
    # Create networks
    att_model = Sal_based_Attention_module()
    salema_poles = SalEMA()
    salema_equator = SalEMA()
    # Load weights. Two separate SalEMA instances are used so that the poles
    # and equator checkpoints do not overwrite each other in a shared module.
    att_model = load_model("attention", att_model).cuda()
    Poles = load_model("poles", salema_poles).cuda()
    Equator = load_model("equator", salema_equator).cuda()
    
    # loop over videos
    print('dataset contains {} videos'.format(len(os.listdir(path_to_vid))))
    print(os.listdir(path_to_vid))
    model = {"attention":att_model,"poles":Poles,"equator":Equator}
    with torch.no_grad():
        test(path_to_vid, output_path,model)
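The `load_model` helper used by these examples is not shown; the sketch below is a hypothetical version matching Example #1's call pattern (the checkpoint layout and file naming are assumptions, not the repository's actual code).

import torch

def load_model(name, model, weights_dir="model_weights"):
    # Hypothetical helper: restore a checkpoint (either a bare state_dict or a
    # dict wrapping one under 'state_dict') into an existing module and return it.
    checkpoint = torch.load("{}/{}.pt".format(weights_dir, name), map_location="cpu")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=False)
    return model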
Example #2
def main(args):

    if args.dataset == "Hollywood-2" or args.dataset == "UCF-sports":
        dst = os.path.join(args.dst, "{}/testing".format(
            args.dataset))  #Hollywood or UCF-sports
    else:
        dst = os.path.join(
            args.dst, "{}_predictions".format(args.pt_model.replace(".pt",
                                                                    "")))
    print("Output directory {}".format(dst))

    # =================================================
    # ================ Data Loading ===================

    #Expect Error if either validation size or train size is 1
    if args.dataset == "DHF1K":
        #print(args.start)
        #print(args.end)
        print("Commencing inference for dataset {}".format(args.dataset))
        dataset = DHF1K_frames(root_path=args.src,
                               load_gt=False,
                               starting_video=int(args.start),
                               number_of_videos=int(args.end),
                               clip_length=CLIP_LENGTH,
                               split=None,
                               resolution=frame_size)
        #add a parameter node = training or validation

    elif args.dataset == "Hollywood-2" or args.dataset == "UCF-sports":
        print("Commencing inference for dataset {}".format(args.dataset))
        dataset = Hollywood_frames(root_path=args.src,
                                   clip_length=CLIP_LENGTH,
                                   resolution=frame_size)
        video_name_list = dataset.video_names(
        )  #match an index to the sample video name
    elif args.dataset == "DAVIS" or args.dataset == "other":
        print("Commencing inference for dataset {}".format(args.dataset))
        dataset = DAVIS_frames(root_path=args.src,
                               clip_length=CLIP_LENGTH,
                               resolution=frame_size)
        video_name_list = dataset.video_names(
        )  #match an index to the sample video name

    print("Size of test set is {}".format(len(dataset)))

    loader = data.DataLoader(dataset, **params)

    # =================================================
    # ================= Load Model ====================

    # Using same kernel size as they do in the DHF1K paper
    # Amaia uses default hidden size 128
    # input size is 1 since we have grayscale images
    if 'SalCLSTM30' in args.pt_model:

        model = SalCLSTM.SalCLSTM30(seed_init=65, freeze=False, residual=False)

        load_model(args.pt_model, model)
        print("Pre-trained model SalCLSTM30 loaded succesfully")

        TEMPORAL = True

    elif 'SalGAN' in args.pt_model:

        model = SalCLSTM.SalGAN()

        load_model(args.pt_model, model)
        print("Pre-trained model tuned SalGAN loaded succesfully")

        TEMPORAL = False

    elif "EMA" in args.pt_model:
        if args.double_ema:
            model = SalEMA.SalEMA2(alpha=args.alpha,
                                   ema_loc_1=EMA_LOC,
                                   ema_loc_2=EMA_LOC_2)
        else:
            model = SalEMA.SalEMA(alpha=args.alpha,
                                  residual=args.residual,
                                  dropout=args.dropout,
                                  ema_loc=EMA_LOC)

        load_model(args.pt_model, model)
        print("Pre-trained model {} loaded succesfully".format(args.pt_model))
        if args.residual:
            print("Residual connection is included.")

        TEMPORAL = True
        print("Alpha tuned to {}".format(model.alpha))

    elif args.pt_model == 'model_weights/salgan_salicon.pt':

        if EMA_LOC is None:
            model = SalCLSTM.SalGAN()
            TEMPORAL = False
            print("Pre-trained model SalBCE loaded succesfully.")
        else:
            model = SalEMA.SalEMA(alpha=args.alpha, ema_loc=EMA_LOC)
            TEMPORAL = True
            print(
                "Pre-trained model SalBCE loaded succesfully. EMA inference will commence soon."
            )

        model.salgan.load_state_dict(torch.load(args.pt_model)['state_dict'])

    elif args.pt_model == '/imatge/lpanagiotis/work/SalCLSTM/src/model_weights/gen_model.pt':
        model = SalCLSTM.SalGAN()
        model.salgan.load_state_dict(torch.load(args.pt_model))
        print("Pre-trained model vanilla SalGAN loaded succesfully")

        TEMPORAL = False
    else:
        print(
            "Your model was not recognized, check the name of the model and try again."
        )
        exit()

    #model = nn.DataParallel(model).cuda()
    dtype = torch.FloatTensor
    if args.use_gpu:
        assert torch.cuda.is_available(), \
            "CUDA is not available on your machine"
        cudnn.benchmark = False  # benchmark=True would add overhead during inference
        model = model.cuda()
        dtype = torch.cuda.FloatTensor
    # ==================================================
    # ================== Inference =====================

    if not os.path.exists(dst):
        os.mkdir(dst)
    else:
        print(
            "Be warned, you are about to write into an existing folder {}. If this is not intentional, cancel now."
            .format(dst))

    # switch to evaluate mode
    model.eval()

    for i, video in enumerate(loader):

        count = 0
        state = None  # Initially no hidden state

        if args.dataset == "DHF1K":

            video_dst = os.path.join(dst, str(int(args.start) + i).zfill(4))
            if not os.path.exists(video_dst):
                os.mkdir(video_dst)

            for j, (clip, _) in enumerate(video):
                clip = Variable(clip.type(dtype).transpose(0, 1),
                                requires_grad=False)
                if args.double_ema:
                    if state is None:
                        state = (None, None)
                    for idx in range(clip.size()[0]):
                        # Compute output
                        state, saliency_map = model.forward(
                            input_=clip[idx],
                            prev_state_1=state[0],
                            prev_state_2=state[1])

                        count += 1  # advance the frame index so each frame gets its own filename
                        saliency_map = saliency_map.squeeze(
                            0)  # Target is 3 dimensional (grayscale image)

                        post_process_saliency_map = (
                            saliency_map - torch.min(saliency_map)
                        ) / (torch.max(saliency_map) - torch.min(saliency_map))
                        utils.save_image(
                            post_process_saliency_map,
                            os.path.join(video_dst,
                                         "{}.png".format(str(count).zfill(4))))

                else:
                    for idx in range(clip.size()[0]):
                        # Compute output
                        if TEMPORAL:
                            #import time
                            #start = time.time()
                            state, saliency_map = model.forward(
                                input_=clip[idx], prev_state=state)
                            #print("Inference time of 1 frame is: {}".format(start-time.time()))
                            #exit()
                        else:
                            saliency_map = model.forward(input_=clip[idx])

                        count += 1
                        saliency_map = saliency_map.squeeze(0)

                        post_process_saliency_map = (
                            saliency_map - torch.min(saliency_map)
                        ) / (torch.max(saliency_map) - torch.min(saliency_map))
                        utils.save_image(
                            post_process_saliency_map,
                            os.path.join(video_dst,
                                         "{}.png".format(str(count).zfill(4))))

                if TEMPORAL:
                    state = repackage_hidden(state)
            print("Video {} done".format(i + int((args.start))))

        elif args.dataset == "Hollywood-2" or args.dataset == "UCF-sports":

            video_dst = os.path.join(
                dst, video_name_list[i],
                '{}_predictions'.format(args.pt_model.replace(".pt", "")))
            print("Destination: {}".format(video_dst))
            if not os.path.exists(video_dst):
                os.mkdir(video_dst)

            for j, (clip, _) in enumerate(video):
                clip = Variable(clip.type(dtype).transpose(0, 1),
                                requires_grad=False)
                for idx in range(clip.size()[0]):
                    # Compute output
                    if TEMPORAL:
                        state, saliency_map = model.forward(input_=clip[idx],
                                                            prev_state=state)
                    else:
                        saliency_map = model.forward(input_=clip[idx])

                    count += 1
                    saliency_map = saliency_map.squeeze(0)

                    post_process_saliency_map = (
                        saliency_map - torch.min(saliency_map)) / (
                            torch.max(saliency_map) - torch.min(saliency_map))
                    utils.save_image(
                        post_process_saliency_map,
                        os.path.join(
                            video_dst,
                            "{}{}.png".format(video_name_list[i][:-1],
                                              str(count).zfill(5))))
                    if count == 1:
                        print(
                            "The final destination is {}. Cancel now if this is incorrect"
                            .format(
                                os.path.join(
                                    video_dst,
                                    "{}{}.png".format(video_name_list[i][:-1],
                                                      str(count).zfill(5)))))

                if TEMPORAL:
                    state = repackage_hidden(state)
            print("Video {} done".format(i + int(args.start)))

        elif args.dataset == "DAVIS" or args.dataset == "other":

            video_dst = os.path.join(dst, video_name_list[i])
            # if "shooting" in video_dst:
            #     # CUDA error: out of memory is encountered whenever inference reaches that vid.
            #     continue
            print("Destination: {}".format(video_dst))
            if not os.path.exists(video_dst):
                os.mkdir(video_dst)

            for j, (clip, _) in enumerate(video):
                clip = Variable(clip.type(dtype).transpose(0, 1),
                                requires_grad=False)

                for idx in range(clip.size()[0]):
                    # Compute output
                    if TEMPORAL:
                        state, saliency_map = model.forward(input_=clip[idx],
                                                            prev_state=state)
                    else:
                        saliency_map = model.forward(input_=clip[idx])

                    count += 1
                    saliency_map = saliency_map.squeeze(0)

                    post_process_saliency_map = (
                        saliency_map - torch.min(saliency_map)) / (
                            torch.max(saliency_map) - torch.min(saliency_map))
                    utils.save_image(
                        post_process_saliency_map,
                        os.path.join(video_dst,
                                     "{}.jpg".format(str(count).zfill(5))))
                    if count == 1:
                        print(
                            "The final destination is {}. Cancel now if this is incorrect"
                            .format(
                                os.path.join(
                                    video_dst,
                                    "{}.jpg".format(str(count).zfill(5)))))

                if TEMPORAL:
                    state = repackage_hidden(state)
            print("Video {} done".format(i + int(args.start)))
Example #3
def main(args, params = params):

    # =================================================
    # ================ Data Loading ===================

    #Expect Error if either validation size or train size is 1
    if args.dataset == "DHF1K":
        print("Commencing training on dataset {}".format(args.dataset))
        train_set = DHF1K_frames(
            root_path = args.src,
            load_gt = True,
            number_of_videos = int(args.end),
            starting_video = int(args.start),
            clip_length = clip_length,
            resolution = frame_size,
            val_perc = args.val_perc,
            split = "train")
        print("Size of train set is {}".format(len(train_set)))
        train_loader = data.DataLoader(train_set, **params)

        if args.val_perc > 0:
            val_set = DHF1K_frames(
                root_path = args.src,
                load_gt = True,
                number_of_videos = int(args.end),
                starting_video = int(args.start),
                clip_length = clip_length,
                resolution = frame_size,
                val_perc = args.val_perc,
                split = "validation")
            print("Size of validation set is {}".format(len(val_set)))
            val_loader = data.DataLoader(val_set, **params)

    elif args.dataset == "Hollywood-2" or args.dataset == "UCF-sports":
        print("Commencing training on dataset {}".format(args.dataset))
        train_set = Hollywood_frames(
            root_path = "/imatge/lpanagiotis/work/{}/training".format(args.dataset),
            #root_path = "/home/linardosHollywood-2/training",
            clip_length = clip_length,
            resolution = frame_size,
            load_gt = True)
        video_name_list = train_set.video_names() #match an index to the sample video name
        train_loader = data.DataLoader(train_set, **params)

    else:
        print('Your dataset was not recognized. Check the name again.')
        exit()
    # =================================================
    # ================ Define Model ===================

    # The seed pertains to initializing the weights with a normal distribution
    # Using brute force for 100 seeds I found the number 65 to provide a good starting point (one that looks close to a saliency map predicted by the original SalGAN)
    temporal = True
    if 'CLSTM56' in args.new_model:
        model = SalGANmore.SalGANplus(seed_init=65, freeze=args.thaw)
        print("Initialized {}".format(args.new_model))
    elif 'CLSTM30' in args.new_model:
        model = SalGANmore.SalCLSTM30(seed_init=65, residual=args.residual, freeze=args.thaw)
        print("Initialized {}".format(args.new_model))
    elif 'SalBCE' in args.new_model:
        model = SalGANmore.SalGAN()
        print("Initialized {}".format(args.new_model))
        temporal = False
    elif 'EMA' in args.new_model:
        if args.double_ema != False:
            model = SalEMA.SalEMA2(alpha=0.3, ema_loc_1=args.ema_loc, ema_loc_2=args.double_ema)
            print("Initialized {}".format(args.new_model))
        else:
            model = SalEMA.SalEMA(alpha=args.alpha, residual=args.residual, dropout= args.dropout, ema_loc=args.ema_loc)
            print("Initialized {} with residual set to {} and dropout set to {}".format(args.new_model, args.residual, args.dropout))
    else:
        print("Your model was not recognized, check the name of the model and try again.")
        exit()
    #criterion = nn.BCEWithLogitsLoss() # This loss combines a Sigmoid layer and the BCELoss in one single class
    criterion = nn.BCELoss()
    #optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=momentum, weight_decay=weight_decay)
    #optimizer = torch.optim.RMSprop(model.parameters(), args.lr, alpha=0.99, eps=1e-08, momentum=momentum, weight_decay=weight_decay)
    #start

    if args.thaw:
        # Load only the unfrozen part to the optimizer

        if args.new_model == 'SalGANplus.pt':
            optimizer = torch.optim.Adam([{'params': model.Gates.parameters()},{'params': model.final_convs.parameters()}], args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)

        elif 'SalCLSTM30' in args.new_model:
            optimizer = torch.optim.Adam([{'params': model.Gates.parameters()}], args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)

    else:
        #optimizer = torch.optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay)
        if args.alpha is None:
            optimizer = torch.optim.Adam([
                {'params': model.salgan.parameters(), 'lr': args.lr, 'weight_decay': weight_decay},
                {'params': model.alpha, 'lr': 0.1}])
        else:
            optimizer = torch.optim.Adam([
                {'params': model.salgan.parameters(), 'lr': args.lr, 'weight_decay': weight_decay}])

        if LEARN_ALPHA_ONLY:
            optimizer = torch.optim.Adam([{'params': [model.alpha]}], 0.1)



    if args.pt_model is None:
        # In truth it's not None; we default to SalGAN or SalBCE (JuanJo's) weights.
        # By setting strict to False we allow the model to load only the matching layers' weights.
        if SALGAN_WEIGHTS == 'model_weights/gen_model.pt':
            model.salgan.load_state_dict(torch.load(SALGAN_WEIGHTS), strict=False)
        else:
            model.salgan.load_state_dict(torch.load(SALGAN_WEIGHTS)['state_dict'], strict=False)


        start_epoch = 1

    else:
        # Load an entire pretrained model
        checkpoint = load_weights(model, args.pt_model)
        model.load_state_dict(checkpoint, strict=False)
        start_epoch = torch.load(args.pt_model, map_location='cpu')['epoch']
        #optimizer.load_state_dict(torch.load(args.pt_model, map_location='cpu')['optimizer'])

        print("Model loaded, commencing training from epoch {}".format(start_epoch))

    dtype = torch.FloatTensor
    if args.use_gpu == 'parallel' or args.use_gpu == 'gpu':
        assert torch.cuda.is_available(), \
            "CUDA is not available on your machine"

        if args.use_gpu == 'parallel':
            model = nn.DataParallel(model).cuda()
        elif args.use_gpu == 'gpu':
            model = model.cuda()
        dtype = torch.cuda.FloatTensor
        cudnn.benchmark = True #https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        criterion = criterion.cuda()
    # =================================================
    # ================== Training =====================


    train_losses = []
    val_losses = []
    starting_time = datetime.datetime.now().replace(microsecond=0)
    print("Training started at : {}".format(starting_time))

    n_iter = 0
    #if "EMA" in args.new_model:
    #    print("Alpha value started at: {}".format(model.alpha))

    for epoch in range(start_epoch, args.epochs+1):

        try:
            #adjust_learning_rate(optimizer, epoch, decay_rate) #Didn't use this after all
            # train for one epoch
            train_loss, n_iter, optimizer = train(train_loader, model, criterion, optimizer, epoch, n_iter, args.use_gpu, args.double_ema, args.thaw, temporal, dtype)

            print("Epoch {}/{} done with train loss {}\n".format(epoch, args.epochs, train_loss))

            if args.val_perc > 0:
                print("Running validation..")
                val_loss = validate(val_loader, model, criterion, epoch, temporal, dtype)
                print("Validation loss: {}".format(val_loss))

            if epoch % plot_every == 0:
                train_losses.append(train_loss.cpu())
                if args.val_perc > 0:
                    val_losses.append(val_loss.cpu())

            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.cpu().state_dict(),
                'optimizer' : optimizer.state_dict()
                }, args.new_model+".pt")

            if args.use_gpu == 'parallel':
                model = nn.DataParallel(model).cuda()
            elif args.use_gpu == 'gpu':
                model = model.cuda()
            else:
                pass

            """
            else:

                print("Training on whole set")
                train_loss, n_iter, optimizer = train(whole_loader, model, criterion, optimizer, epoch, n_iter)
                print("Epoch {}/{} done with train loss {}".format(epoch, args.epochs, train_loss))
            """

        except RuntimeError:
            print("A memory error was encountered. Further training aborted.")
            epoch = epoch - 1
            break

    print("Training of {} started at {} and finished at : {} \n Now saving..".format(args.new_model, starting_time, datetime.datetime.now().replace(microsecond=0)))
    #if "EMA" in args.new_model:
    #    print("Alpha value tuned to: {}".format(model.alpha))
    # ===================== #
    # ======  Saving ====== #

    # If I try saving in regular intervals I have to move the model to CPU and back to GPU.
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.cpu().state_dict(),
        'optimizer' : optimizer.state_dict()
        }, args.new_model+".pt")

    """
    hyperparameters = {
        'momentum' : momentum,
        'weight_decay' : weight_decay,
        'args.lr' : learning_rate,
        'decay_rate' : decay_rate,
        'args.epochs' : args.epochs,
        'batch_size' : batch_size
    }
    """

    if args.val_perc > 0:
        to_plot = {
            'epoch_ticks': list(range(start_epoch, args.epochs+1, plot_every)),
            'train_losses': train_losses,
            'val_losses': val_losses
            }
        with open('to_plot.pkl', 'wb') as handle:
            pickle.dump(to_plot, handle, protocol=pickle.HIGHEST_PROTOCOL)
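SalEMA's temporal component, judging by the name and the `alpha` hyperparameter that Example #3 optionally learns, is an exponential moving average over per-frame activations. A minimal conceptual sketch of that recurrence (an illustration of the idea only, not the repository's SalEMA module):

import torch

def ema_step(current, state, alpha=0.1):
    # One exponential-moving-average update: blend the new activation into the
    # running state, which acts as the recurrent memory carried between frames.
    if state is None:
        return current
    return alpha * current + (1.0 - alpha) * state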
Example #4
def main(args, params = params):

    # =================================================
    # ================ Data Loading ===================

    #Expect Error if either validation size or train size is 1
    if args.dataset == "Poles" :
        print("Commencing training on dataset {}".format(args.dataset))
        train_set = Poles(
            root_path = args.src,
            load_gt = True,
            number_of_videos = int(args.end),
            starting_video = int(args.start),
            clip_length = clip_length,
            resolution = frame_size,
            val_perc = args.val_perc,
            split = "train")
        print("Size of train set is {}".format(len(train_set)))
        train_loader = data.DataLoader(train_set, **params)

        if args.val_perc > 0:
            val_set = Poles(
                root_path = args.src,
                load_gt = True,
                number_of_videos = int(args.end),
                starting_video = int(args.start),
                clip_length = clip_length,
                resolution = frame_size,
                val_perc = args.val_perc,
                split = "validation")
            print("Size of validation set is {}".format(len(val_set)))
            val_loader = data.DataLoader(val_set, **params)

    if  args.dataset == "Equator":
        print("Commencing training on dataset {}".format(args.dataset))
        train_set = Equator(
            root_path = args.src,
            load_gt = True,
            number_of_videos = int(args.end),
            starting_video = int(args.start),
            clip_length = clip_length,
            resolution = frame_size,
            val_perc = args.val_perc,
            split = "train")
        print("Size of train set is {}".format(len(train_set)))
        train_loader = data.DataLoader(train_set, **params)

        if args.val_perc > 0:
            val_set = Equator(
                root_path = args.src,
                load_gt = True,
                number_of_videos = int(args.end),
                starting_video = int(args.start),
                clip_length = clip_length,
                resolution = frame_size,
                val_perc = args.val_perc,
                split = "validation")
            print("Size of validation set is {}".format(len(val_set)))
            val_loader = data.DataLoader(val_set, **params)


    else:
        print('Your dataset was not recognized. Check the name again.')
        exit()
    # =================================================
    # ================ Define Model ===================

    # The seed pertains to initializing the weights with a normal distribution
    # Using brute force for 100 seeds I found the number 65 to provide a good starting point (one that looks close to a saliency map predicted by the original SalGAN)
    temporal = True

    if 'EMA' in args.new_model:

        if 'Poles' in args.new_model:
            model = SalEMA.Poles_EMA(alpha=None, ema_loc=args.ema_loc)
            print("Initialized {}".format(args.new_model))
        elif 'Equator' in args.new_model:
            model = SalEMA.Equator_EMA(alpha=None, ema_loc=args.ema_loc)
            print("Initialized {}".format(args.new_model))
Example #5
def main():
    ''' read frames in path_indata and generate frame-wise saliency maps in path_output '''
    # optional two command-line arguments
    args = parse_args()
    config_path = args.config
    with open(config_path) as f:
        config = edict(yaml.safe_load(f))
    # path_indata = '/data2/yuanx/QoEData/sailency-models/TASED-Net/example/'
    # path_output = '/data2/yuanx/QoEData/sailency-models/TASED-Net/output/'
    # model_path = '/data2/yuanx/QoEData/sailency-models/TASED-Net/models/'
    path_output = config.output_path
    model_path = config.model_path
    len_temporal = 10

    if not os.path.isdir(path_output):
        os.makedirs(path_output)

    model = SalEMA(alpha=config.model.alpha,
                   residual=config.model.residual,
                   dropout=config.model.dropout,
                   ema_loc=config.model.ema_loc)
    model = load_model(model_path, model)
    torch.backends.cudnn.benchmark = False
    model = model.cuda()
    model.eval()
    list_video = [
        os.path.join(config.data.base_path, v) for v in config.data.video_list
    ]
    # list_indata.sort()
    for vname in list_video:
        print('processing ' + vname)
        # list_frames = [f for f in os.listdir(os.path.join(path_indata, vname)) if os.path.isfile(os.path.join(path_indata, vname, f))]
        # list_frames.sort()
        vname = vname + config.data.quailty_index + '.' + config.data.video_suffix
        capture = cv2.VideoCapture(vname)
        read_flag, img = capture.read()
        i = 0
        path_outdata = os.path.join(path_output,
                                    vname.split('/')[-1].split('.')[0])
        encoded_vid_path = os.path.join(path_outdata, "sailency.mp4")

        if not os.path.isdir(path_outdata):
            os.makedirs(path_outdata)

        # for i in range(len(list_frames)):
        state = None
        while read_flag:
            # process this clip
            state = process(model, path_outdata, img, i, state)
            if (i + 1) % len_temporal == 0:
                state = repackage_hidden(state)
            read_flag, img = capture.read()
            i += 1

        capture.release()
        encoding_result = subprocess.run([
            "ffmpeg", "-y", "-start_number", '0', "-i",
            f"{path_outdata}/%06d.jpg", "-loglevel", "error", "-vcodec",
            "libx264", "-pix_fmt", "yuv420p", "-crf", "23", encoded_vid_path
        ],
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE,
                                         universal_newlines=True)

        if encoding_result.returncode != 0:
            # Encoding failed
            print("ENCODING FAILED")
            print(encoding_result.stdout)
            print(encoding_result.stderr)
            # exit()
            continue
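`repackage_hidden`, called after every clip in Example #2 and every `len_temporal` frames above, is not defined in these snippets. A common implementation (an assumption about this helper, following the usual PyTorch recurrent pattern) detaches the hidden state from the computation graph so the history does not keep growing across frames:

import torch

def repackage_hidden(state):
    # Detach a tensor, or a nested tuple of tensors, from its autograd history
    # so the graph does not accumulate from frame to frame.
    if state is None:
        return None
    if isinstance(state, torch.Tensor):
        return state.detach()
    return tuple(repackage_hidden(s) for s in state)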
Example #6
def main(args):


    dst = os.path.join(args.dst, "{}_predictions".format(args.pt_model.replace(".pt", "")))
    print("Output directory {}".format(dst))

    # =================================================
    # ================ Data Loading ===================

    #Expect Error if either validation size or train size is 1

    if args.dataset == "Equator" or args.dataset == "Poles" or args.dataset == "other" :
        print("Commencing inference for dataset {}".format(args.dataset))
        dataset = TEST(
            root_path = args.src,
            clip_length = CLIP_LENGTH,
            resolution = frame_size)
        video_name_list = dataset.video_names() #match an index to the sample video name
    else:
        print('dataset not defined')
        exit()


    print("Size of test set is {}".format(len(dataset)))

    loader = data.DataLoader(dataset, **params)

    # =================================================
    # ================= Load Model ====================

    # Using same kernel size as they do in the DHF1K paper
    # Amaia uses default hidden size 128
    # input size is 1 since we have grayscale images

    if "EMA" in args.pt_model:
        if "poles" in args.pt_model:
            model = SalEMA.Poles_EMA(alpha=args.alpha, ema_loc=EMA_LOC)
        elif "equator" in args.pt_model:
            model = SalEMA.Equator_EMA(alpha=args.alpha, ema_loc=EMA_LOC)

        load_model(args.pt_model, model)
        print("Pre-trained model {} loaded succesfully".format(args.pt_model))

        TEMPORAL = True
        print("Alpha tuned to {}".format(model.alpha))

    else:
        print("Your model was not recognized not (pole or equator), check the name of the model and try again.")
        exit()

    dtype = torch.FloatTensor
    if args.use_gpu:
        assert torch.cuda.is_available(), \
            "CUDA is not available on your machine"
        cudnn.benchmark = True 
        model = model.cuda()
        dtype = torch.cuda.FloatTensor


    # ================== Inference =====================

    if not os.path.exists(dst):
        os.mkdir(dst)
    else:
        print(" you are about to write on an existing folder {}. If this is not intentional cancel now.".format(dst))

    # switch to evaluate mode
    model.eval()

    for i, video in enumerate(loader):

        count = 0
        state = None # Initially no hidden state

        elif args.dataset == "Poles" or args.dataset == "Equator":

            video_dst = os.path.join(dst, video_name_list[i])
            # if "shooting" in video_dst:
            #     # CUDA error: out of memory is encountered whenever inference reaches that vid.
            #     continue
            print("Destination: {}".format(video_dst))
            if not os.path.exists(video_dst):
                os.mkdir(video_dst)

            for j, (clip, _) in enumerate(video):
                clip = Variable(clip.type(dtype).transpose(0,1), requires_grad=False)

                for idx in range(clip.size()[0]):
                    # Compute output
                    if TEMPORAL:
                        state, saliency_map = model.forward(input_ = clip[idx], prev_state = state)
                    else:
                        saliency_map = model.forward(input_ = clip[idx])

                    
                    saliency_map = saliency_map.squeeze(0)
    
                    post_process_saliency_map = (saliency_map-torch.min(saliency_map))/(torch.max(saliency_map)-torch.min(saliency_map))
                    utils.save_image(post_process_saliency_map, os.path.join(video_dst, "{}.png".format(str(count).zfill(4))))
                    if count == 0:
                        print("The final destination is {}".format(os.path.join(video_dst)))
                    count+=1
                if TEMPORAL:
                    state = repackage_hidden(state)
            print("Video {} done".format(i+int(args.start)))