def main(DATASET, LABELS, CLASS_IDS, BATCH_SIZE, ANNOTATION_FILE, SEQ_SIZE=16, 
         STEP=16, nstrokes=-1, N_EPOCHS=25):
    '''
    Build BoVW one-hot representations from pre-extracted stroke features and
    train a Transformer encoder classifier on the resulting sequences.
    
    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    ANNOTATION_FILE : str
        path to the file with stroke class annotations
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride to the next clip. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23), ...
    nstrokes : int
        partial extraction of features (do not execute for entire dataset)
    N_EPOCHS : int
        no. of training epochs
    
    Returns:
    --------
    int
        0 on completion (the validation accuracy computed by predict() is not returned)
    
    '''
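    # A minimal sketch (illustrative only) of the clip windowing implied by
    # SEQ_SIZE/STEP above; the actual windowing is done by the dataset class:
    #   >>> SEQ_SIZE, STEP = 16, 8
    #   >>> [(s, s + SEQ_SIZE - 1) for s in range(0, 3 * STEP, STEP)]
    #   [(0, 15), (8, 23), (16, 31)]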
    ###########################################################################
    
    attn_utils.seed_everything(1234)
    
    if not os.path.isdir(log_path):
        os.makedirs(log_path)
    
    # Read the strokes 
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))
    
#    extract_of_features(feat_path, DATASET, LABELS, train_lst, val_lst)
    
    features, stroke_names_id = attn_utils.read_feats(feat_path, feat, snames)
    # get matrix of features from dictionary (N, vec_size)
    vecs = []
    for key in sorted(list(features.keys())):
        vecs.append(features[key])
    vecs = np.vstack(vecs)
    
    vecs[np.isnan(vecs)] = 0
    vecs[np.isinf(vecs)] = 0
    
    #fc7 layer output size (4096) 
    INP_VEC_SIZE = vecs.shape[-1]
    print("INP_VEC_SIZE = ", INP_VEC_SIZE)
    
    km_filepath = os.path.join(log_path, km_filename)
    # Fit the codebook only if it has not been saved by a previous run.
    if not os.path.isfile(km_filepath+"_C"+str(cluster_size)+".pkl"):
        km_model = make_codebook(vecs, cluster_size)  #, model_type='gmm')
        # Save the trained codebook to disk for reuse on the val/test splits.
        print("Writing the KMeans model to disk...")
        with open(km_filepath+"_C"+str(cluster_size)+".pkl", "wb") as fp:
            pickle.dump(km_model, fp)
    else:
        # Load the codebook saved by a previous run.
        with open(km_filepath+"_C"+str(cluster_size)+".pkl", "rb") as fp:
            km_model = pickle.load(fp)
        
    print("Create numpy one hot representation for train features...")
    onehot_feats = create_bovw_onehot(features, stroke_names_id, km_model)
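    # create_bovw_onehot is assumed to quantise each per-frame feature vector to
    # its nearest KMeans centroid and one-hot encode the cluster id, roughly
    # (stroke_feats below is a placeholder for one stroke's (T, vec_size) array):
    #   >>> words = km_model.predict(stroke_feats)   # (T,) cluster ids
    #   >>> onehot = np.eye(cluster_size)[words]     # (T, cluster_size)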
    
    ft_path = os.path.join(log_path, "C"+str(cluster_size)+"_train.pkl")
    with open(ft_path, "wb") as fp:
        pickle.dump(onehot_feats, fp)
    
    ###########################################################################
    
    features_val, stroke_names_id_val = attn_utils.read_feats(feat_path, feat_val, 
                                                              snames_val)
    
    print("Create numpy one hot representation for val features...")
    onehot_feats_val = create_bovw_onehot(features_val, stroke_names_id_val, km_model)
    
    ft_path_val = os.path.join(log_path, "C"+str(cluster_size)+"_val.pkl")
    with open(ft_path_val, "wb") as fp:
        pickle.dump(onehot_feats_val, fp)
    
    ###########################################################################
    
    features_test, stroke_names_id_test = attn_utils.read_feats(feat_path, feat_test, 
                                                                snames_test)
    
    print("Create numpy one hot representation for test features...")
    onehot_feats_test = create_bovw_onehot(features_test, stroke_names_id_test, km_model)
    
    ft_path_test = os.path.join(log_path, "C"+str(cluster_size)+"_test.pkl")
    with open(ft_path_test, "wb") as fp:
        pickle.dump(onehot_feats_test, fp)
    
    ###########################################################################    
    # Create a Dataset    
    train_dataset = StrokeFeaturePairsDataset(ft_path, train_lst, DATASET, LABELS, CLASS_IDS, 
                                         frames_per_clip=SEQ_SIZE, extracted_frames_per_clip=2,
                                         step_between_clips=STEP, train=True)
    val_dataset = StrokeFeaturePairsDataset(ft_path_val, val_lst, DATASET, LABELS, CLASS_IDS, 
                                         frames_per_clip=SEQ_SIZE, extracted_frames_per_clip=2,
                                         step_between_clips=STEP, train=False)
    
    # get labels
    labs_keys, labs_values = attn_utils.get_cluster_labels(ANNOTATION_FILE)
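    # get_cluster_labels is assumed to return parallel lists of stroke
    # identifiers (labs_keys) and their integer class ids (labs_values) parsed
    # from ANNOTATION_FILE.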
    # create a weighted sampler to counter class imbalance
    samples_weight = attn_utils.get_sample_weights(train_dataset, labs_keys, labs_values, 
                                                   train_lst)
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
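    # get_sample_weights is assumed to weight each training clip inversely to
    # its class frequency so the sampler over-draws minority classes, roughly
    # (clip_labels below is a hypothetical per-clip label array):
    #   >>> counts = np.bincount(clip_labels)
    #   >>> samples_weight = 1.0 / counts[clip_labels]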
    
    # shuffle must stay off since a sampler is supplied; NumPy is re-seeded in
    # each DataLoader worker for reproducibility.
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                              sampler=sampler,
                              worker_init_fn=lambda worker_id: np.random.seed(12))
    
    val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    data_loaders = {"train": train_loader, "test": val_loader}

    num_classes = len(list(set(labs_values)))
    
    ###########################################################################    
    
    # load model and set loss function
    ntokens = cluster_size # the size of vocabulary
    emsize = 200 # embedding dimension
    nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
    nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    nhead = 2 # the number of heads in the multiheadattention models
    dropout = 0.2 # the dropout value
    model = tt.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, #num_classes, 
                                dropout)
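    # tt.TransformerModel is assumed to follow the PyTorch word-language-model
    # example signature (ntoken, ninp, nhead, nhid, nlayers, dropout): the
    # cluster_size BoVW codewords act as the token vocabulary and the decoder
    # projects back to ntokens logits until it is replaced below.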
#    model = load_weights(model_path, model, N_EPOCHS, 
#                                    "S30"+"C"+str(cluster_size)+"_SGD")
    
    # Pretrained weights could be copied here via the commented load_weights call above.
    
    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()
    
    # Replace the language-model decoder with a freshly initialised linear
    # classification head over the stroke classes.
    model.decoder = nn.Linear(model.ninp, num_classes)
    initrange = 0.1
    model.decoder.bias.data.zero_()
    model.decoder.weight.data.uniform_(-initrange, initrange)

    model = model.to(device)

    print("Params to learn:")
    for name, param in model.named_parameters():
        if param.requires_grad == True:
            print("\t", name)

    
    # Observe that all parameters are being optimized
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Decay LR by a factor of 0.1 every 15 epochs
    scheduler = StepLR(optimizer, step_size=15, gamma=0.1)
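    # With lr=0.001, step_size=15 and gamma=0.1, the learning rate is 0.001 for
    # epochs 0-14 and 0.0001 from epoch 15 onward (one decay within the default
    # 25 epochs).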
    
#    lr = 5.0 # learning rate
#    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
#    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
    ###########################################################################
    # Training the model    
    
    start = time.time()
    
    model = train_model(features, stroke_names_id, model, data_loaders, criterion, 
                        optimizer, scheduler, labs_keys, labs_values,
                        num_epochs=N_EPOCHS)
    
    end = time.time()
    
    # save the best performing model
    save_model_checkpoint(log_path, model, N_EPOCHS, 
                                     "S"+str(SEQ_SIZE)+"C"+str(cluster_size)+"_SGD")
    # Load model checkpoints
    model = load_weights(log_path, model, N_EPOCHS, 
                                    "S"+str(SEQ_SIZE)+"C"+str(cluster_size)+"_SGD")
    
    print("Total Execution time for {} epoch : {}".format(N_EPOCHS, (end-start)))

    ###########################################################################
    
    acc = predict(features_val, stroke_names_id_val, model, data_loaders, labs_keys, 
                  labs_values, SEQ_SIZE, phase='test')
    
    ###########################################################################    
            
    # call count_parameters(model) to display the total no. of parameters
    print("#Parameters : {} ".format(autoenc_utils.count_parameters(model)))
    return 0
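

# NOTE: the second definition of main() below shadows the one above; once this
# module is loaded, only the variant below remains bound to the name main.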
def main(DATASET,
         LABELS,
         CLASS_IDS,
         BATCH_SIZE,
         ANNOTATION_FILE,
         SEQ_SIZE=16,
         STEP=16,
         nstrokes=-1,
         N_EPOCHS=25):
    '''
    Build BoVW one-hot representations from pre-extracted stroke features and
    evaluate a GRU-based BoVW classifier whose weights are loaded from a saved
    checkpoint (the training call in this variant is commented out).
    
    Parameters:
    -----------
    DATASET : str
        path to the video dataset
    LABELS : str
        path containing stroke labels
    CLASS_IDS : str
        path to txt file defining classes, similar to THUMOS
    BATCH_SIZE : int
        size for batch of clips
    ANNOTATION_FILE : str
        path to the file with stroke class annotations
    SEQ_SIZE : int
        no. of frames in a clip (min. 16 for 3D CNN extraction)
    STEP : int
        stride to the next clip. If SEQ_SIZE=16, STEP=8, use frames (0, 15), (8, 23), ...
    nstrokes : int
        partial extraction of features (do not execute for entire dataset)
    N_EPOCHS : int
        no. of training epochs
    
    Returns:
    --------
    accuracy computed by predict() on the test-split features
    
    '''
    ###########################################################################

    attn_utils.seed_everything(1234)

    if not os.path.isdir(log_path):
        os.makedirs(log_path)

    # Read the strokes
    # Divide the highlight dataset files into training, validation and test sets
    train_lst, val_lst, test_lst = autoenc_utils.split_dataset_files(DATASET)
    print("No. of training videos : {}".format(len(train_lst)))

    #    extract_of_features(feat_path, DATASET, LABELS, train_lst, val_lst)

    features, stroke_names_id = attn_utils.read_feats(feat_path, feat, snames)
    # get matrix of features from dictionary (N, vec_size)
    vecs = []
    for key in sorted(list(features.keys())):
        vecs.append(features[key])
    vecs = np.vstack(vecs)

    vecs[np.isnan(vecs)] = 0
    vecs[np.isinf(vecs)] = 0

    #    avg, std = np.mean(vecs, axis=0), np.std(vecs, axis=0)
    #    vecs = (vecs - avg) / std
    #    features = normalize_feats(features, avg, std)

    #fc7 layer output size (4096)
    INP_VEC_SIZE = vecs.shape[-1]
    print("INP_VEC_SIZE = ", INP_VEC_SIZE)

    km_filepath = os.path.join(log_path, km_filename)
    # Fit the codebook only if it has not been saved by a previous run.
    if not os.path.isfile(km_filepath + "_C" + str(cluster_size) + ".pkl"):
        km_model = make_codebook(vecs, cluster_size)  #, model_type='gmm')
        # Save the trained codebook to disk for reuse on the val/test splits.
        print("Writing the KMeans model to disk...")
        with open(km_filepath + "_C" + str(cluster_size) + ".pkl", "wb") as fp:
            pickle.dump(km_model, fp)
    else:
        # Load the codebook saved by a previous run.
        with open(km_filepath + "_C" + str(cluster_size) + ".pkl", "rb") as fp:
            km_model = pickle.load(fp)

    print("Create numpy one hot representation for train features...")
    onehot_feats = create_bovw_onehot(features, stroke_names_id, km_model)

    ft_path = os.path.join(log_path, "C" + str(cluster_size) + "_train.pkl")
    with open(ft_path, "wb") as fp:
        pickle.dump(onehot_feats, fp)
    snames_path = os.path.join(log_path,
                               "C" + str(cluster_size) + "_snames_train.pkl")
    with open(snames_path, "wb") as fp:
        pickle.dump(stroke_names_id, fp)
    ###########################################################################

    features_val, stroke_names_id_val = attn_utils.read_feats(
        feat_path, feat_val, snames_val)
    #    features_val = normalize_feats(features_val, avg, std)

    print("Create numpy one hot representation for val features...")
    onehot_feats_val = create_bovw_onehot(features_val, stroke_names_id_val,
                                          km_model)

    ft_path_val = os.path.join(log_path, "C" + str(cluster_size) + "_val.pkl")
    with open(ft_path_val, "wb") as fp:
        pickle.dump(onehot_feats_val, fp)
    snames_path_val = os.path.join(log_path,
                                   "C" + str(cluster_size) + "_snames_val.pkl")
    with open(snames_path_val, "wb") as fp:
        pickle.dump(stroke_names_id_val, fp)
    ###########################################################################
    features_test, stroke_names_id_test = attn_utils.read_feats(
        feat_path, feat_test, snames_test)
    #    features_test = normalize_feats(features_test, avg, std)
    print("Create numpy one hot representation for val features...")
    onehot_feats_test = create_bovw_onehot(features_test, stroke_names_id_test,
                                           km_model)

    ft_path_test = os.path.join(log_path,
                                "C" + str(cluster_size) + "_test.pkl")
    with open(ft_path_test, "wb") as fp:
        pickle.dump(onehot_feats_test, fp)
    snames_path_test = os.path.join(log_path,
                                    "C" + str(cluster_size) + "_snames_test.pkl")
    with open(snames_path_test, "wb") as fp:
        pickle.dump(stroke_names_id_test, fp)

    ###########################################################################
    # Create a Dataset

    #    ft_path = os.path.join(base_name, ft_dir, feat)
    train_dataset = StrokeFeatureSequenceDataset(ft_path,
                                                 train_lst,
                                                 DATASET,
                                                 LABELS,
                                                 CLASS_IDS,
                                                 frames_per_clip=SEQ_SIZE,
                                                 extracted_frames_per_clip=16,
                                                 step_between_clips=STEP,
                                                 train=True)
    #    ft_path_val = os.path.join(base_name, ft_dir, feat_val)
    # NOTE: the evaluation loader below is built from the test-split features.
    val_dataset = StrokeFeatureSequenceDataset(ft_path_test,
                                               test_lst,
                                               DATASET,
                                               LABELS,
                                               CLASS_IDS,
                                               frames_per_clip=SEQ_SIZE,
                                               extracted_frames_per_clip=16,
                                               step_between_clips=STEP,
                                               train=False)

    # get labels
    labs_keys, labs_values = attn_utils.get_cluster_labels(ANNOTATION_FILE)
    # create a weighted sampler to counter class imbalance
    samples_weight = attn_utils.get_sample_weights(train_dataset, labs_keys,
                                                   labs_values, train_lst)
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

    # NumPy is re-seeded in each DataLoader worker for reproducibility.
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=sampler,
                              worker_init_fn=lambda worker_id: np.random.seed(12))

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False)

    data_loaders = {"train": train_loader, "test": val_loader}

    num_classes = len(list(set(labs_values)))

    #    vis_clusters(features, onehot_feats, stroke_names_id, 2, DATASET, log_path)

    ###########################################################################

    # load model and set loss function
    model = attn_model.GRUBoWHAClassifier(INPUT_SIZE, HIDDEN_SIZE, num_classes,
                                          N_LAYERS, bidirectional)
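    # GRUBoWHAClassifier is assumed to consume one-hot BoVW sequences of width
    # INPUT_SIZE and emit per-stroke class logits; INPUT_SIZE, HIDDEN_SIZE,
    # N_LAYERS and bidirectional are module-level constants defined elsewhere.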

    #    model = load_weights(base_name, model, N_EPOCHS, "Adam")

    #    for ft in model.parameters():
    #        ft.requires_grad = False

    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()
    model = model.to(device)
    #    print("Params to learn:")
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            #            print("\t", name)

    # Observe that all parameters are being optimized
    #    optimizer_ft = torch.optim.Adam(model.parameters(), lr=0.001)
    optimizer_ft = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # Decay LR by a factor of 0.1 every 10 epochs
    exp_lr_scheduler = StepLR(optimizer_ft, step_size=10, gamma=0.1)
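    # If the commented-out training call below were enabled, this schedule would
    # give lr=0.01 for epochs 0-9, 0.001 for epochs 10-19 and 0.0001 afterwards.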

    ###########################################################################
    # Training the model

    start = time.time()

    #    model = train_model(features, stroke_names_id, model, data_loaders, criterion,
    #                        optimizer_ft, exp_lr_scheduler, labs_keys, labs_values,
    #                        num_epochs=N_EPOCHS)

    end = time.time()

    #    # save the best performing model
    #    attn_utils.save_model_checkpoint(log_path, model, N_EPOCHS,
    #                                     "S"+str(SEQ_SIZE)+"C"+str(cluster_size)+"_SGD")
    # Load the model checkpoint saved by a previous training run
    # (the training call above is commented out in this variant).
    model = attn_utils.load_weights(
        log_path, model, N_EPOCHS,
        "S" + str(SEQ_SIZE) + "C" + str(cluster_size) + "_SGD")

    print("Total Execution time for {} epoch : {}".format(
        N_EPOCHS, (end - start)))

    ###########################################################################

    acc = predict(features_test,
                  stroke_names_id_test,
                  model,
                  data_loaders,
                  labs_keys,
                  labs_values,
                  SEQ_SIZE,
                  phase='test')

    # call count_parameters(model) to display the total no. of parameters
    print("#Parameters : {} ".format(autoenc_utils.count_parameters(model)))
    return acc