# Example #1
# 0
def main(DATASET,
         LABELS,
         CLASS_IDS,
         BATCH_SIZE,
         ANNOTATION_FILE,
         SEQ_SIZE=16,
         STEP=16,
         nstrokes=-1,
         N_EPOCHS=25,
         base_name=""):
    '''
    Train a ConvVAE on stroke clips and pickle the extracted features.

    Builds clip-level train / val datasets, trains
    ``conv_encdec_model.ConvVAE`` through ``train_model`` with an MSE
    criterion, reloads weights from a checkpoint in ``base_name`` when one
    exists, and finally writes the features returned by
    ``extract_attn_feats`` for the 'train', 'val' and 'test' partitions.

    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    ANNOTATION_FILE : str
        passed to ``attn_utils.get_cluster_labels``; the keys / values it
        returns are forwarded to ``train_model``
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
        Only used for the training / validation datasets; the feature
        dump below always uses a stride of 16.
    nstrokes : int
        not used by this function (the feature dump always passes -1);
        kept so existing callers keep working
    N_EPOCHS : int
        number of epochs passed to ``train_model``; also part of the
        checkpoint file name
    base_name : str
        directory holding the checkpoint and the pickled features. An
        empty string resolves to the current working directory.

    Returns:
    --------
    model
        the ConvVAE after training (and after the optional checkpoint load)

    '''
    ###########################################################################
    # os.makedirs("") raises, so only create a directory when a non-empty
    # path was supplied (the default "" means the current directory).
    if base_name:
        os.makedirs(base_name, exist_ok=True)

    # Read the strokes
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))

    ###########################################################################
    # Create a Dataset
    # Clip level transform. Use this with framewiseTransform flag turned off
    clip_transform = transforms.Compose([
        videotransforms.CenterCrop(224),
        videotransforms.ToPILClip(),
        videotransforms.Resize((112, 112)),
        videotransforms.ToTensor(),
        videotransforms.Normalize(),
    ])
    # Arguments shared by the train and val datasets
    # (CricketStrokesFlowDataset is the alternative mentioned by the author).
    dataset_kwargs = dict(frames_per_clip=SEQ_SIZE,
                          step_between_clips=STEP,
                          framewiseTransform=False,
                          transform=clip_transform)
    train_dataset = CricketStrokesDataset(train_lst, DATASET, LABELS,
                                          CLASS_IDS, train=True,
                                          **dataset_kwargs)
    val_dataset = CricketStrokesDataset(val_lst, DATASET, LABELS,
                                        CLASS_IDS, train=False,
                                        **dataset_kwargs)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

    # NOTE: the "test" phase handed to train_model is the validation split.
    data_loaders = {"train": train_loader, "test": val_loader}

    ###########################################################################

    labs_keys, labs_values = attn_utils.get_cluster_labels(ANNOTATION_FILE)

    ###########################################################################
    # load model and set loss function
    model = conv_encdec_model.ConvVAE()
    model = model.to(device)

    # Setup the loss fxn
    criterion = nn.MSELoss()

    # Collect (and list) every parameter that will be optimized
    print("Params to learn:")
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("ConvVAE : {}".format(name))

    optimizer_ft = torch.optim.SGD(params_to_update, lr=0.01, momentum=0.9)

    # Decay LR by a factor of 0.1 every 10 epochs
    lr_scheduler = StepLR(optimizer_ft, step_size=10, gamma=0.1)

    ###########################################################################
    # Training the model
    start = time.time()

    model = train_model(model,
                        data_loaders,
                        criterion,
                        optimizer_ft,
                        lr_scheduler,
                        labs_keys,
                        labs_values,
                        seq=8,
                        num_epochs=N_EPOCHS)

    end = time.time()

    print("Total Execution time for {} epoch : {}".format(
        N_EPOCHS, (end - start)))

    ###########################################################################
    # NOTE(review): the trained weights are not saved here, so when a
    # checkpoint from an earlier run exists it replaces the weights that
    # were just trained - confirm this is intended.
    model_name = os.path.join(base_name,
                              "conv_vae_ep" + str(N_EPOCHS) + "_SGD.pt")
    if os.path.isfile(model_name):
        model.load_state_dict(torch.load(model_name))
        print("Loading ConvVAE weights... : {}".format(model_name))

    ###########################################################################

    print("Writing prediction dictionary....")
    for partition in ("train", "val", "test"):
        feats_path = os.path.join(base_name,
                                  "conv_vae_{}.pkl".format(partition))
        # Features already dumped for this partition: nothing to do.
        if os.path.isfile(feats_path):
            continue
        if base_name:
            os.makedirs(base_name, exist_ok=True)
        # Extraction covers the whole partition (nstrokes=-1) with a fixed
        # stride of 16, independent of the STEP / nstrokes arguments.
        feats_dict, stroke_names = extract_attn_feats(model, DATASET, LABELS,
                                                      CLASS_IDS, BATCH_SIZE,
                                                      SEQ_SIZE, 16, partition,
                                                      -1, base_name)
        with open(feats_path, "wb") as fp:
            pickle.dump(feats_dict, fp)
        names_path = os.path.join(base_name,
                                  "conv_vae_snames_{}.pkl".format(partition))
        with open(names_path, "wb") as fp:
            pickle.dump(stroke_names, fp)

    print("#Parameters ConvVAE : {} ".format(
        autoenc_utils.count_parameters(model)))

    return model
def extract_sequence_feats(model_path, DATASET, LABELS, CLASS_IDS, BATCH_SIZE, 
                           INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS, SEQ_SIZE=16, STEP=16, 
                           partition='all', nstrokes=-1):
    '''
    Extract sequence features from AutoEncoder.

    Frames are first turned into feature vectors by ``Img2Vec``; each
    stroke's vectors are then passed through ``model.encoder`` of an
    ``autoenc.AutoEncoderRNN`` and collected stroke by stroke.
    
    Parameters:
    -----------
    model_path : str
        relative path to the checkpoint file for Autoencoder. When None,
        the autoencoder keeps its freshly initialised weights.
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    INPUT_SIZE : int
        size of the extracted feature vector (output of ResNet). Input size of
        Autoencoder.
    HIDDEN_SIZE : int
        hidden size of autoencoder. #Parameters of autoencoder depend on it.
    NUM_LAYERS : int
        No. of GRU layers in the Autoencoder
    SEQ_SIZE : int
        no. of frames in a clip
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
    partition : str
        'all' / 'train' / 'test' / 'val' : Videos to be considered
    nstrokes : int
        partial extraction of features (do not execute for entire dataset).
        A negative value (default -1) means no limit.
    
    Returns:
    --------
    trajectories, stroke_names
        as returned by ``autoenc_utils.group_strokewise``

    Raises:
    -------
    ValueError
        if ``partition`` is not one of the four supported values
    
    '''
    
    ###########################################################################
    # Read the strokes 
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))
    
    #####################################################################
    
    if partition == 'all':
        # Build a new list so the training list itself is left untouched.
        partition_lst = list(train_lst)
        partition_lst.extend(val_lst)
        partition_lst.extend(test_lst)
    elif partition == 'train':
        partition_lst = train_lst
    elif partition == 'val':
        partition_lst = val_lst
    elif partition == 'test':
        partition_lst = test_lst
    else:
        # Previously an unknown value surfaced later as an UnboundLocalError.
        raise ValueError(
            "partition should be 'all' / 'train' / 'val' / 'test', got {!r}"
            .format(partition))
        
    ###########################################################################
    # Create a Dataset    
    # Frame-wise transform
    clip_transform = transforms.Compose([transforms.ToPILImage(),
                                         transforms.Resize((224, 224)),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                              std=[0.229, 0.224, 0.225]),]) 

    part_dataset = CricketStrokesDataset(partition_lst, DATASET, LABELS, CLASS_IDS, 
                                         frames_per_clip=SEQ_SIZE, step_between_clips=STEP,
                                         train=True, framewiseTransform=True,
                                         transform=clip_transform)
    
    # shuffle=False: the stroke-boundary logic below compares each sample's
    # stroke name with the previous one, so samples must arrive in order.
    data_loader = DataLoader(dataset=part_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ###########################################################################
    # Create a model and load the weights of AutoEncoder
    model = autoenc.AutoEncoderRNN(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS)
    if model_path is not None:
        print("Loading model ...")
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
    
    model = model.to(device)

    ###########################################################################
    # Validate / Evaluate
    model.eval()
    stroke_names = []
    trajectories, stroke_traj = [], []
    num_strokes = 0
    # Swap in Clip2Vec() here to use the clip-level branch below.
    extractor = Img2Vec()
    prev_stroke = None
    
    print("Total Batches : {} :: BATCH_SIZE : {}".format(len(data_loader), BATCH_SIZE))
    for bno, (inputs, vid_path, stroke, _) in enumerate(data_loader):
        # get video clips (B, SL, C, H, W)
        print("Batch No : {}".format(bno))
        # Extract spatial features using 2D ResNet
        if isinstance(extractor, Img2Vec):
            inputs = torch.stack([extractor.get_vec(x) for x in inputs])
        # Extract spatio-temporal features from clip using 3D ResNet
        else:
            # for SEQ_LEN >= 16
            inputs = inputs.permute(0, 2, 1, 3, 4).float()
            inputs = extractor.get_vec(inputs)
            
        # convert to start frames and end frames from tensors to lists
        stroke = [s.tolist() for s in stroke]
        inputs_lst, batch_stroke_names = autoenc_utils.separate_stroke_tensors(
            inputs, vid_path, stroke)
        
        if bno == 0:
            prev_stroke = batch_stroke_names[0]
        
        for enc_idx, enc_input in enumerate(inputs_lst):
            if prev_stroke != batch_stroke_names[enc_idx]:
                # a new stroke started: flush the finished one
                if len(stroke_traj) > 0:
                    num_strokes += 1
                    trajectories.append(stroke_traj)
                    stroke_names.append(prev_stroke)
                    stroke_traj = []
            
            enc_output = model.encoder(enc_input.to(device))
            enc_output = enc_output.squeeze(axis=1).cpu().data.numpy()
            # one entry per row of the encoder output
            stroke_traj.extend([enc_output[i, :] for i in range(enc_output.shape[0])])
            prev_stroke = batch_stroke_names[enc_idx]
            
        # Stop early only when a non-negative limit was requested. The
        # previous test (nstrokes >= -1) was also true for the default -1
        # and ended the extraction after the very first batch.
        if nstrokes > -1 and num_strokes >= nstrokes:
            break
       
    # for last batch only if extracted for full dataset
    if len(stroke_traj) > 0 and nstrokes < 0:
        trajectories.append(stroke_traj)
        stroke_names.append(batch_stroke_names[-1])
        
    trajectories, stroke_names = autoenc_utils.group_strokewise(trajectories, stroke_names)
    
    return trajectories, stroke_names
# Example #3
# 0
def extract_attn_feats(model,
                       DATASET,
                       LABELS,
                       CLASS_IDS,
                       BATCH_SIZE,
                       SEQ_SIZE=16,
                       STEP=16,
                       partition='train',
                       nstrokes=-1,
                       base_name=""):
    '''
    Collect per-stroke features from the second output of ``model``.

    Every batch of clips is passed through ``model`` with gradients
    disabled; the second element of the returned triple (named ``mu`` in
    the code) is split stroke-wise by
    ``autoenc_utils.separate_stroke_tensors`` and accumulated until the
    stroke name changes.

    Parameters:
    -----------
    model : callable
        network called as ``model(inputs)`` and expected to return three
        values; it is switched to eval mode here
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes
    BATCH_SIZE : int
        size for batch of clips
    SEQ_SIZE : int
        no. of frames in a clip
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
    partition : str
        'train' / 'test' / 'val' : Videos to be considered
    nstrokes : int
        partial extraction of features; values below 0 mean no limit
    base_name : str
        accepted for interface compatibility; not read by this function

    Returns:
    --------
    features_dictionary, stroke_names
        dictionary mapping each stroke name to a numpy array of its
        features, and the list of stroke names. For an unknown
        ``partition`` a message is printed and None is returned instead.

    '''

    ###########################################################################
    # Split the dataset files and pick the requested partition
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))

    known_partitions = ('train', 'val', 'test')
    if partition not in known_partitions:
        print("Partition should be : train / val / test")
        return
    partition_lst = (train_lst, val_lst,
                     test_lst)[known_partitions.index(partition)]

    ###########################################################################
    # Clip level transform (framewiseTransform stays off).
    # NOTE(review): no Normalize step is applied here - confirm this matches
    # the preprocessing the model was trained with.
    transform_steps = [
        videotransforms.CenterCrop(224),
        videotransforms.ToPILClip(),
        videotransforms.Resize((112, 112)),
        videotransforms.ToTensor(),
    ]
    clip_transform = transforms.Compose(transform_steps)

    part_dataset = CricketStrokesDataset(partition_lst, DATASET, LABELS,
                                         CLASS_IDS,
                                         frames_per_clip=SEQ_SIZE,
                                         step_between_clips=STEP,
                                         train=False,
                                         framewiseTransform=False,
                                         transform=clip_transform)
    data_loader = DataLoader(dataset=part_dataset, batch_size=BATCH_SIZE,
                             shuffle=False)

    ###########################################################################
    # Evaluate
    model.eval()
    stroke_names, trajectories = [], []
    current_traj = []      # feature rows of the stroke being assembled
    num_strokes = 0        # strokes flushed to `trajectories` so far
    prev_stroke = None
    print("Total Batches : {} :: BATCH_SIZE : {}".format(
        len(data_loader), BATCH_SIZE))

    for bno, (inputs, vid_path, stroke, _) in enumerate(data_loader):
        # swap axes 1 and 2 of the 5-D batch before feeding the model
        inputs = inputs.permute(0, 2, 1, 3, 4).float().to(device)

        # inference only: no gradient tracking
        with torch.no_grad():
            _, mu, _ = model(inputs)

        # stroke boundaries arrive as tensors; turn them into plain lists
        stroke = [bounds.tolist() for bounds in stroke]
        # the second model output, not the reconstruction, is the feature
        inputs_lst, batch_stroke_names = \
            autoenc_utils.separate_stroke_tensors(mu, vid_path, stroke)

        if bno == 0:
            prev_stroke = batch_stroke_names[0]

        for idx, stroke_feats in enumerate(inputs_lst):
            name = batch_stroke_names[idx]
            # stroke changed: flush what was gathered for the previous one
            if prev_stroke != name and len(current_traj) > 0:
                num_strokes += 1
                trajectories.append(current_traj)
                stroke_names.append(prev_stroke)
                current_traj = []

            feats = stroke_feats.cpu().data.numpy()
            # store the rows of this chunk one by one
            for row in range(feats.shape[0]):
                current_traj.append(feats[row, :])
            prev_stroke = name

        # honour the stroke limit only when one was requested
        if nstrokes > -1 and num_strokes >= nstrokes:
            break

    # the trailing stroke is only kept when the full dataset was requested
    if nstrokes < 0 and len(current_traj) > 0:
        trajectories.append(current_traj)
        stroke_names.append(batch_stroke_names[-1])

    # dictionary of features keyed by stroke name
    features = {
        name: np.array(traj)
        for name, traj in zip(stroke_names, trajectories)
    }

    return features, stroke_names
def extract_2DCNN_feats(DATASET, LABELS, CLASS_IDS, BATCH_SIZE, \
                        partition='all', nstrokes=-1):
    '''
    Extract frame-level features with the ``Img2Vec`` extractor.

    Single-frame clips are read from ``CricketStrokesDataset``, turned into
    feature vectors by ``Img2Vec.get_vec`` and collected stroke by stroke.
    
    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    partition : str
        'all' / 'train' / 'test' / 'val' : Videos to be considered
    nstrokes : int
        partial extraction of features (if don't want to execute for entire dataset).
        A negative value (default -1) means no limit.
    
    Returns:
    --------
    trajectories, stroke_names
        as returned by ``autoenc_utils.group_strokewise``

    Raises:
    -------
    ValueError
        if ``partition`` is not one of the four supported values
    
    '''
    
    ###########################################################################
    # Read the strokes 
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))
    
    #####################################################################
    
    if partition == 'all':
        # Build a new list so the training list itself is left untouched.
        partition_lst = list(train_lst)
        partition_lst.extend(val_lst)
        partition_lst.extend(test_lst)
    elif partition == 'train':
        partition_lst = train_lst
    elif partition == 'val':
        partition_lst = val_lst
    elif partition == 'test':
        partition_lst = test_lst
    else:
        # Previously an unknown value surfaced later as an UnboundLocalError.
        raise ValueError(
            "partition should be 'all' / 'train' / 'val' / 'test', got {!r}"
            .format(partition))
        
    ###########################################################################
    # Create a Dataset    
    # Frame-wise transform
    clip_transform = transforms.Compose([transforms.ToPILImage(),
                                         transforms.Resize((224, 224)),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                              std=[0.229, 0.224, 0.225]),]) 
    # For using Frame level transform, the framewiseTransform flag turned on
    part_dataset = CricketStrokesDataset(partition_lst, DATASET, LABELS, CLASS_IDS, 
                                         frames_per_clip=1, train=True, 
                                         framewiseTransform=True,
                                         transform=clip_transform)
    
    # shuffle=False: the stroke-boundary logic below compares each sample's
    # stroke name with the previous one, so samples must arrive in order.
    data_loader = DataLoader(dataset=part_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ###########################################################################
    # Extract using the data_loader
    stroke_names = []
    trajectories, stroke_traj = [], []
    num_strokes = 0
    extractor = Img2Vec()
    prev_stroke = None
    
    print("Total Batches : {} :: BATCH_SIZE : {}".format(len(data_loader), BATCH_SIZE))
    for bno, (inputs, vid_path, stroke, _) in enumerate(data_loader):
        # get video clips (B, SL, C, H, W)
        print("Batch No : {}".format(bno))
        # Extract spatial features using 2D ResNet
        inputs = torch.stack([extractor.get_vec(x) for x in inputs])
            
        # convert to start frames and end frames from tensors to lists
        stroke = [s.tolist() for s in stroke]
        inputs_lst, batch_stroke_names = autoenc_utils.separate_stroke_tensors(
            inputs, vid_path, stroke)
        
        if bno == 0:
            prev_stroke = batch_stroke_names[0]
        
        for enc_idx, enc_input in enumerate(inputs_lst):
            if prev_stroke != batch_stroke_names[enc_idx]:
                # a new stroke started: flush the finished one
                if len(stroke_traj) > 0:
                    num_strokes += 1
                    trajectories.append(stroke_traj)
                    stroke_names.append(prev_stroke)
                    stroke_traj = []
            
            # the extracted vectors are used as features directly
            enc_output = enc_input.squeeze(axis=1).cpu().data.numpy()
            # one entry per row of the extracted features
            stroke_traj.extend([enc_output[i, :] for i in range(enc_output.shape[0])])
            prev_stroke = batch_stroke_names[enc_idx]
            
        # Stop once the requested number of strokes has been reached. Using
        # '>=' (not '==') matters because one batch can complete several
        # strokes and push the count past nstrokes.
        if nstrokes > -1 and num_strokes >= nstrokes:
            break
       
    # for last batch only if extracted for full dataset
    if len(stroke_traj) > 0 and nstrokes < 0:
        trajectories.append(stroke_traj)
        stroke_names.append(batch_stroke_names[-1])
        
    trajectories, stroke_names = autoenc_utils.group_strokewise(trajectories, stroke_names)
    return trajectories, stroke_names
def extract_3DCNN_feats(DATASET, LABELS, CLASS_IDS, BATCH_SIZE, SEQ_SIZE=16, STEP=16, \
                        model_path=None, nclasses=5, partition='all', nstrokes=-1):
    '''
    Extract clip-level features with the ``Clip2Vec`` extractor.

    Clips of ``SEQ_SIZE`` frames are read from ``CricketStrokesDataset``,
    turned into feature vectors by ``Clip2Vec.get_vec`` and collected
    stroke by stroke.
    
    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
    model_path : str or None
        forwarded to ``Clip2Vec`` (presumably a checkpoint path - confirm
        against Clip2Vec)
    nclasses : int
        forwarded to ``Clip2Vec`` (presumably the number of output classes
        of the loaded network - confirm against Clip2Vec)
    partition : str
        'all' / 'train' / 'test' / 'val' : Videos to be considered
    nstrokes : int
        partial extraction of features (do not execute for entire dataset).
        A negative value (default -1) means no limit.
    
    Returns:
    --------
    trajectories, stroke_names
        as returned by ``autoenc_utils.group_strokewise``

    Raises:
    -------
    ValueError
        if ``partition`` is not one of the four supported values
    AssertionError
        if ``SEQ_SIZE`` is smaller than 16
    
    '''
    
    ###########################################################################
    # Read the strokes 
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))
    
    #####################################################################
    
    if partition == 'all':
        # Build a new list so the training list itself is left untouched.
        partition_lst = list(train_lst)
        partition_lst.extend(val_lst)
        partition_lst.extend(test_lst)
    elif partition == 'train':
        partition_lst = train_lst
    elif partition == 'val':
        partition_lst = val_lst
    elif partition == 'test':
        partition_lst = test_lst
    else:
        # Previously an unknown value surfaced later as an UnboundLocalError.
        raise ValueError(
            "partition should be 'all' / 'train' / 'val' / 'test', got {!r}"
            .format(partition))
        
    ###########################################################################
    # Create a Dataset    
    # Clip level transform. Use this with framewiseTransform flag turned off
    clip_transform = transforms.Compose([videotransforms.CenterCrop(224),
                                         videotransforms.ToPILClip(), 
                                         videotransforms.Resize((112, 112)),
                                         videotransforms.ToTensor(), 
                                         videotransforms.Normalize(),
                                        ])

    part_dataset = CricketStrokesDataset(partition_lst, DATASET, LABELS, CLASS_IDS, 
                                         frames_per_clip=SEQ_SIZE, 
                                         step_between_clips=STEP, train=True, 
                                         framewiseTransform=False,
                                         transform=clip_transform)
    
    # shuffle=False: the stroke-boundary logic below compares each sample's
    # stroke name with the previous one, so samples must arrive in order.
    data_loader = DataLoader(dataset=part_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ###########################################################################
    # Validate / Evaluate
    stroke_names = []
    trajectories, stroke_traj = [], []
    num_strokes = 0
    extractor = Clip2Vec(model_path, nclasses)
    prev_stroke = None
    
    print("Total Batches : {} :: BATCH_SIZE : {}".format(len(data_loader), BATCH_SIZE))
    assert SEQ_SIZE>=16, "SEQ_SIZE should be >= 16"
    for bno, (inputs, vid_path, stroke, _) in enumerate(data_loader):
        # get video clips (B, SL, C, H, W)
        print("Batch No : {}".format(bno))
        # Extract spatio-temporal features from clip using 3D ResNet (For SL >= 16)
        inputs = inputs.permute(0, 2, 1, 3, 4).float()
        inputs = extractor.get_vec(inputs)
        
        # convert to start frames and end frames from tensors to lists
        stroke = [s.tolist() for s in stroke]
        inputs_lst, batch_stroke_names = autoenc_utils.separate_stroke_tensors(
            inputs, vid_path, stroke)
        
        if bno == 0:
            prev_stroke = batch_stroke_names[0]
        
        for enc_idx, enc_input in enumerate(inputs_lst):
            if prev_stroke != batch_stroke_names[enc_idx]:
                # a new stroke started: flush the finished one
                if len(stroke_traj) > 0:
                    num_strokes += 1
                    trajectories.append(stroke_traj)
                    stroke_names.append(prev_stroke)
                    stroke_traj = []
            
            # the extracted vectors are used as features directly
            enc_output = enc_input.squeeze(axis=1).cpu().data.numpy()
            # one entry per row of the extracted features
            stroke_traj.extend([enc_output[i, :] for i in range(enc_output.shape[0])])
            prev_stroke = batch_stroke_names[enc_idx]
            
        # Stop once the requested number of strokes has been reached. Using
        # '>=' (not '==') matters because one batch can complete several
        # strokes and push the count past nstrokes.
        if nstrokes > -1 and num_strokes >= nstrokes:
            break
       
    # for last batch only if extracted for full dataset
    if len(stroke_traj) > 0 and nstrokes < 0:
        trajectories.append(stroke_traj)
        stroke_names.append(batch_stroke_names[-1])
        
    trajectories, stroke_names = autoenc_utils.group_strokewise(trajectories, stroke_names)
    return trajectories, stroke_names
# Example #6
# 0
def finetune_3DCNN(DATASET, LABELS, CLASS_IDS, BATCH_SIZE, ANNOTATION_FILE, 
                   SEQ_SIZE=16, STEP=16, nstrokes=-1, N_EPOCHS=25, base_name=""):
    '''
    Build a 3D ResNet-18 classifier head, optionally restore its checkpoint,
    and extract stroke-wise clip features from the validation videos.
    
    The training step itself is currently commented out, so this function only
    sets up the model / optimizer and then runs the extraction loop.
    
    NOTE(review): the extraction loop calls ``extractor.get_vec`` but no
    ``extractor`` is created in this function (its construction is commented
    out), so it has to exist at module level. The ``model`` that is built and
    restored here is not used by the loop -- confirm this is intended.
    
    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    ANNOTATION_FILE : str
        passed to get_cluster_labels(); the number of distinct label values
        defines the size of the new classification layer
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
    nstrokes : int
        partial extraction of features: stop once at least this many strokes
        have been collected. A negative value processes the whole loader.
    N_EPOCHS : int
        used to build the checkpoint file name (and by the disabled training)
    base_name : str
        directory in which the checkpoint file is looked up
    
    Returns:
    --------
    trajectories, stroke_names
        as returned by autoenc_utils.group_strokewise()
    
    '''
    
    ###########################################################################
    # Read the strokes 
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))
        
    ###########################################################################
    # Create a Dataset    
    # Clip level transform. Use this with framewiseTransform flag turned off
    clip_transform = transforms.Compose([videotransforms.CenterCrop(224),
                                         videotransforms.ToPILClip(), 
                                         videotransforms.Resize((112, 112)),
                                         videotransforms.ToTensor(), 
                                         videotransforms.Normalize(),
                                        ])

    train_dataset = CricketStrokesDataset(train_lst, DATASET, LABELS, CLASS_IDS, 
                                         frames_per_clip=SEQ_SIZE, 
                                         step_between_clips=STEP, train=True, 
                                         framewiseTransform=False,
                                         transform=clip_transform)
    val_dataset = CricketStrokesDataset(val_lst, DATASET, LABELS, CLASS_IDS, 
                                         frames_per_clip=SEQ_SIZE, 
                                         step_between_clips=STEP, train=False, 
                                         framewiseTransform=False,
                                         transform=clip_transform)
    
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # The validation loader is stored under the "test" key.
    data_loaders = {"train": train_loader, "test": val_loader}

    ###########################################################################
    
    labs_keys, labs_values = get_cluster_labels(ANNOTATION_FILE)
    num_classes = len(set(labs_values))
    
    ###########################################################################    
    # load model and set loss function
    model = torchvision.models.video.r3d_18(pretrained=True, progress=True)
    
    # Freeze the pretrained backbone; only the replaced fc layer stays trainable.
    for ft in model.parameters():
        ft.requires_grad = False
    
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)
    
    # load checkpoint (map_location lets a GPU-saved file load on any device)
    ckpt_path = os.path.join(base_name, "3dresnet18_ep"+str(N_EPOCHS)+"_Adam.pt")
    if os.path.isfile(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    
    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()
    
    # Layers to finetune. Last layer should be displayed
    print("Params to learn:")
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t",name)
    
    # Only the parameters collected above are optimized
    optimizer_ft = torch.optim.Adam(params_to_update, lr=0.001)
    
    # Decay LR by a factor of 0.1 every 8 epochs
    exp_lr_scheduler = StepLR(optimizer_ft, step_size=8, gamma=0.1)
    
    # Training is disabled; criterion / optimizer_ft / exp_lr_scheduler are
    # only consumed by the block below when it is re-enabled.
#    ###########################################################################
#    # Training the model    
#    
#    start = time.time()
#    
#    model_ft = train_model(model, data_loaders, criterion, optimizer_ft, 
#                           exp_lr_scheduler, labs_keys, labs_values,
#                           num_epochs=N_EPOCHS)
#        
#    end = time.time()
#    
#    # save the best performing model
#    save_model_checkpoint(base_name, model_ft, N_EPOCHS, "Adam")
#    
#    print("Total Execution time for {} epoch : {}".format(N_EPOCHS, (end-start)))
    
    ###########################################################################
    # Validate / Evaluate
    stroke_names = []
    trajectories, stroke_traj = [], []
    num_strokes = 0
    model = model.eval()
    prev_stroke = None
    
    print("Total Batches : {} :: BATCH_SIZE : {}".format(len(data_loaders['test']), BATCH_SIZE))
    assert SEQ_SIZE>=16, "SEQ_SIZE should be >= 16"
    # Pure inference: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for bno, (inputs, vid_path, stroke, _) in enumerate(data_loaders['test']):
            # get video clips (B, SL, C, H, W)
            print("Batch No : {}".format(bno))
            # Extract spatio-temporal features from clip using 3D ResNet (For SL >= 16)
            inputs = inputs.permute(0, 2, 1, 3, 4).float()
            inputs = extractor.get_vec(inputs)
            
            # convert to start frames and end frames from tensors to lists
            stroke = [s.tolist() for s in stroke]
            inputs_lst, batch_stroke_names = autoenc_utils.separate_stroke_tensors(
                inputs, vid_path, stroke)
            
            if bno == 0:
                prev_stroke = batch_stroke_names[0]
            
            for enc_idx, enc_input in enumerate(inputs_lst):
                current_stroke = batch_stroke_names[enc_idx]
                # A change of stroke name closes the trajectory collected so far.
                if prev_stroke != current_stroke and len(stroke_traj) > 0:
                    num_strokes += 1
                    trajectories.append(stroke_traj)
                    stroke_names.append(prev_stroke)
                    stroke_traj = []
                
                enc_output = enc_input.squeeze(axis=1).cpu().data.numpy()
                # one feature row per clip is appended to the running stroke
                stroke_traj.extend(enc_output[i, :] for i in range(enc_output.shape[0]))
                prev_stroke = current_stroke
            
            # A single batch may complete several strokes, so compare with >=
            # rather than == to make sure the partial extraction really stops.
            if nstrokes >= 0 and num_strokes >= nstrokes:
                break
       
    # for last batch only if extracted for full dataset
    if len(stroke_traj) > 0 and nstrokes < 0:
        trajectories.append(stroke_traj)
        # prev_stroke is the stroke the pending trajectory belongs to
        stroke_names.append(prev_stroke)
        
    trajectories, stroke_names = autoenc_utils.group_strokewise(trajectories, stroke_names)
    return trajectories, stroke_names
예제 #7
0
def main(DATASET,
         LABELS,
         CLASS_IDS,
         BATCH_SIZE,
         ANNOTATION_FILE,
         SEQ_SIZE=16,
         STEP=16,
         nstrokes=-1,
         N_EPOCHS=25,
         base_name=""):
    '''
    Train a Conv3D encoder / decoder pair on stroke clips and predict on the
    validation set.
    
    The function builds the train / validation datasets, creates a weighted
    sampler for the training loader (sample weights are cached as a pickle in
    ``base_name``), trains with train_model(), saves and reloads the
    checkpoint, and writes the prediction dictionary to
    ``<base_name>/pred_dict.pkl``.
    
    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    ANNOTATION_FILE : str
        passed to attn_utils.get_cluster_labels(); the number of distinct
        label values is used as the number of classes
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
    nstrokes : int
        not used by this function; kept so the signature stays unchanged
    N_EPOCHS : int
        number of epochs passed to train_model() and to the checkpoint helpers
    base_name : str
        output directory (created if missing). An empty string means the
        current working directory.
    
    Returns:
    --------
    encoder, decoder
        the models returned by attn_utils.load_attn_model_checkpoint()
    
    '''
    ###########################################################################
    # seed everything
    seed = 1234
    attn_utils.seed_everything(seed)
    # os.makedirs("") raises FileNotFoundError, so skip it for the default
    # empty base_name (paths are then relative to the working directory).
    if base_name:
        os.makedirs(base_name, exist_ok=True)

    # Read the strokes
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))

    ###########################################################################
    # Create a Dataset
    # Clip level transform. Use this with framewiseTransform flag turned off
    train_transform = transforms.Compose([
        videotransforms.RandomCrop(224),
        videotransforms.ToPILClip(),
        videotransforms.Resize((112, 112)),
        videotransforms.ToTensor(),
        videotransforms.Normalize(),
    ])
    test_transform = transforms.Compose([
        videotransforms.CenterCrop(224),
        videotransforms.ToPILClip(),
        videotransforms.Resize((112, 112)),
        videotransforms.ToTensor(),
        videotransforms.Normalize(),
    ])
    train_dataset = CricketStrokesDataset(train_lst,
                                          DATASET,
                                          LABELS,
                                          CLASS_IDS,
                                          frames_per_clip=SEQ_SIZE,
                                          step_between_clips=STEP,
                                          train=True,
                                          framewiseTransform=False,
                                          transform=train_transform)
    val_dataset = CricketStrokesDataset(val_lst,
                                        DATASET,
                                        LABELS,
                                        CLASS_IDS,
                                        frames_per_clip=SEQ_SIZE,
                                        step_between_clips=STEP,
                                        train=False,
                                        framewiseTransform=False,
                                        transform=test_transform)

    ###########################################################################

    labs_keys, labs_values = attn_utils.get_cluster_labels(ANNOTATION_FILE)

    num_classes = len(set(labs_values))

    # create weighted Sampler for class imbalance; the weights are cached on
    # disk, keyed by the number of classes and the training-set size
    weights_path = os.path.join(
        base_name,
        "weights_c" + str(num_classes) + "_" + str(len(train_dataset)) + ".pkl")
    if not os.path.isfile(weights_path):
        samples_weight = attn_utils.get_sample_weights(train_dataset,
                                                       labs_keys, labs_values,
                                                       train_lst)
        with open(weights_path, "wb") as fp:
            pickle.dump(samples_weight, fp)
    # Always read back from the cache so both paths yield the same object type.
    with open(weights_path, "rb") as fp:
        samples_weight = pickle.load(fp)
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    # The original passed worker_init_fn=np.random.seed(12): that expression
    # runs the seeding immediately and hands None to the DataLoader. Keep the
    # effective behavior (seed NumPy once here) but state it explicitly.
    np.random.seed(12)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=sampler)

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

    # The validation loader is stored under the "test" key.
    data_loaders = {"train": train_loader, "test": val_loader}

    ###########################################################################
    # load model and set loss function
    encoder = conv_attn_model.Conv3DEncoder(HIDDEN_SIZE, 1, bidirectional)
    decoder = conv_attn_model.Conv3DDecoder(HIDDEN_SIZE, num_classes, 1, 1,
                                            bidirectional)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()

    # Display the trainable parameters of both networks
    print("Params to learn:")
    for name, param in encoder.named_parameters():
        if param.requires_grad:
            print("Encoder : {}".format(name))
    for name, param in decoder.named_parameters():
        if param.requires_grad:
            print("Decoder : {}".format(name))

    # Each network gets its own SGD optimizer over all of its parameters
    encoder_optimizer = torch.optim.SGD(encoder.parameters(),
                                        lr=0.01,
                                        momentum=0.9)
    decoder_optimizer = torch.optim.SGD(decoder.parameters(),
                                        lr=0.01,
                                        momentum=0.9)

    # Decay LR by a factor of 0.1 every 10 epochs
    # NOTE(review): the scheduler is attached to encoder_optimizer only, so it
    # does not decay the decoder learning rate -- confirm this is intended.
    lr_scheduler = StepLR(encoder_optimizer, step_size=10, gamma=0.1)

    ###########################################################################
    # Training the model
    start = time.time()

    (encoder, decoder) = train_model(encoder,
                                     decoder,
                                     data_loaders,
                                     criterion,
                                     encoder_optimizer,
                                     decoder_optimizer,
                                     lr_scheduler,
                                     labs_keys,
                                     labs_values,
                                     num_epochs=N_EPOCHS)

    end = time.time()

    # save the best performing model
    attn_utils.save_attn_model_checkpoint(base_name, (encoder, decoder),
                                          N_EPOCHS, "SGD")
    # Load model checkpoints
    encoder, decoder = attn_utils.load_attn_model_checkpoint(
        base_name, encoder, decoder, N_EPOCHS, "SGD")

    print("Total Execution time for {} epoch : {}".format(
        N_EPOCHS, (end - start)))

    ###########################################################################

    print("Writing prediction dictionary....")
    # The second return value of predict() is not needed here.
    pred_out_dict, _ = predict(encoder,
                               decoder,
                               data_loaders,
                               criterion,
                               labs_keys,
                               labs_values,
                               phase='test')

    with open(os.path.join(base_name, "pred_dict.pkl"), "wb") as fp:
        pickle.dump(pred_out_dict, fp)

    # report the model sizes
    print("#Parameters Encoder : {} ".format(
        autoenc_utils.count_parameters(encoder)))
    print("#Parameters Decoder : {} ".format(
        autoenc_utils.count_parameters(decoder)))

    return encoder, decoder
예제 #8
0
def main(DATASET,
         LABELS,
         CLASS_IDS,
         BATCH_SIZE,
         ANNOTATION_FILE,
         SEQ_SIZE=16,
         STEP=16,
         nstrokes=-1,
         N_EPOCHS=25,
         base_name=""):
    '''
    Train a conv_attn_model.C3DGRUv2Orig classifier on stroke clips and run
    prediction on the validation set.
    
    The function builds the train / validation datasets, creates a weighted
    sampler for the training loader (sample weights are cached as a pickle in
    ``base_name``), initialises the model from a pretrained C3D network via
    copy_pretrained_weights(), trains with train_model(), then saves and
    reloads the checkpoint before predicting.
    
    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    ANNOTATION_FILE : str
        passed to attn_utils.get_cluster_labels(); the number of distinct
        label values is used as the number of classes
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride for next example. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23) ...
    nstrokes : int
        not used by this function; kept so the signature stays unchanged
    N_EPOCHS : int
        number of epochs passed to train_model() and to the checkpoint helpers
    base_name : str
        output directory (created if missing). An empty string means the
        current working directory.
    
    Returns:
    --------
    model
        the model returned by attn_utils.load_weights()
    
    '''
    # os.makedirs("") raises FileNotFoundError, so skip it for the default
    # empty base_name (paths are then relative to the working directory).
    if base_name:
        os.makedirs(base_name, exist_ok=True)
    seed = 1234
    attn_utils.seed_everything(seed)
    ###########################################################################
    # Read the strokes
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))

    ###########################################################################
    # Create a Dataset
    # Clip level transform. Use this with framewiseTransform flag turned off
    train_transforms = transforms.Compose([
        videotransforms.RandomCrop(300),
        videotransforms.ToPILClip(),
        videotransforms.Resize((112, 112)),
        videotransforms.ToTensor(),
        videotransforms.Normalize(),
    ])
    test_transforms = transforms.Compose([
        videotransforms.CenterCrop(300),
        videotransforms.ToPILClip(),
        videotransforms.Resize((112, 112)),
        videotransforms.ToTensor(),
        videotransforms.Normalize(),
    ])
    train_dataset = CricketStrokesDataset(train_lst,
                                          DATASET,
                                          LABELS,
                                          CLASS_IDS,
                                          frames_per_clip=SEQ_SIZE,
                                          step_between_clips=STEP,
                                          train=True,
                                          framewiseTransform=False,
                                          transform=train_transforms)
    val_dataset = CricketStrokesDataset(val_lst,
                                        DATASET,
                                        LABELS,
                                        CLASS_IDS,
                                        frames_per_clip=SEQ_SIZE,
                                        step_between_clips=STEP,
                                        train=False,
                                        framewiseTransform=False,
                                        transform=test_transforms)

    ###########################################################################

    labs_keys, labs_values = attn_utils.get_cluster_labels(ANNOTATION_FILE)

    num_classes = len(set(labs_values))

    # create weighted Sampler for class imbalance; the weights are cached on
    # disk, keyed by the number of classes and the training-set size
    weights_path = os.path.join(
        base_name,
        "weights_c" + str(num_classes) + "_" + str(len(train_dataset)) + ".pkl")
    if not os.path.isfile(weights_path):
        samples_weight = attn_utils.get_sample_weights(train_dataset,
                                                       labs_keys, labs_values,
                                                       train_lst)
        with open(weights_path, "wb") as fp:
            pickle.dump(samples_weight, fp)
    # Always read back from the cache so both paths yield the same object type.
    with open(weights_path, "rb") as fp:
        samples_weight = pickle.load(fp)
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    # The original passed worker_init_fn=np.random.seed(12): that expression
    # runs the seeding immediately and hands None to the DataLoader. Keep the
    # effective behavior (seed NumPy once here) but state it explicitly.
    np.random.seed(12)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=sampler)

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

    # The validation loader is stored under the "test" key.
    data_loaders = {"train": train_loader, "test": val_loader}

    ###########################################################################
    # load model and set loss function
    model = conv_attn_model.C3DGRUv2Orig(HIDDEN_SIZE, 1, num_classes,
                                         bidirectional)
    model_pretrained = c3d.C3D()
    # map_location="cpu" lets weights saved on a GPU load on any host; the
    # values are copied into model_pretrained, which has not been moved to
    # any device at this point.
    model_pretrained.load_state_dict(
        torch.load("../localization_rnn/" + wts_path, map_location="cpu"))
    copy_pretrained_weights(model_pretrained, model)
    model = model.to(device)

    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()

    # Collect and display the trainable parameters
    print("Params to learn:")
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t {}".format(name))

    # Only the parameters collected above are optimized
    optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 30 epochs
    lr_scheduler = StepLR(optimizer_ft, step_size=30, gamma=0.1)

    ###########################################################################
    # Training the model
    start = time.time()

    model = train_model(model,
                        data_loaders,
                        criterion,
                        optimizer_ft,
                        lr_scheduler,
                        labs_keys,
                        labs_values,
                        num_epochs=N_EPOCHS)

    end = time.time()

    # Tag shared by the save and load calls so they address the same checkpoint
    ckpt_tag = "SGD_c8_c3dgruEp60Step30"
    # save the best performing model
    attn_utils.save_model_checkpoint(base_name, model, N_EPOCHS, ckpt_tag)
    # Load model checkpoints
    model = attn_utils.load_weights(base_name, model, N_EPOCHS, ckpt_tag)

    print("Total Execution time for {} epoch : {}".format(
        N_EPOCHS, (end - start)))

    ###########################################################################

    print("Predicting ...")
    # The return value of predict() is not used here.
    predict(model, data_loaders, labs_keys, labs_values, phase='test')

    print("#Parameters : {} ".format(autoenc_utils.count_parameters(model)))

    return model