Example #1
def main():
    cudnn.benchmark = False
    cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    logdir = os.path.join(params['log'], cur_time)
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    writer = SummaryWriter(log_dir=logdir)

    print("Loading dataset")
    train_dataloader = \
        DataLoader(
            VideoDataset(params['dataset'], mode='train', clip_len=params['clip_len'], frame_sample_rate=params['frame_sample_rate']),
            batch_size=params['batch_size'], shuffle=True, num_workers=params['num_workers'])

    val_dataloader = \
        DataLoader(
            VideoDataset(params['dataset'], mode='validation', clip_len=params['clip_len'], frame_sample_rate=params['frame_sample_rate']),
            batch_size=params['batch_size'], shuffle=False, num_workers=params['num_workers'])

    print("load model")
    model = threestream.resnet50(class_num=params['num_classes']) #change the model here
    print("model", model)
    
    if params['pretrained'] is not None:
        pretrained_dict = torch.load(params['pretrained'], map_location='cpu')
        try:
            model_dict = model.module.state_dict()
        except AttributeError:
            model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        print("load pretrain model")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    
    model = model.cuda(params['gpu'][0])
    model = nn.DataParallel(model, device_ids=params['gpu'])  # multi-GPU

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(model.parameters(), lr=params['learning_rate'], momentum=params['momentum'], weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=params['step'], gamma=0.1)

    model_save_dir = os.path.join(params['save_path'], cur_time)
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    for epoch in range(params['epoch_num']):
        train(model, train_dataloader, epoch, criterion, optimizer, writer)
        if epoch % 2 == 0:
            validation(model, val_dataloader, epoch, criterion, optimizer, writer)
        scheduler.step()
        if epoch % 1 == 0:
            checkpoint = os.path.join(model_save_dir,
                                      "clip_len_" + str(params['clip_len']) + "frame_sample_rate_" + str(params['frame_sample_rate']) + "_checkpoint_" + str(epoch) + ".pth.tar")
            torch.save(model.module.state_dict(), checkpoint)

    writer.close()
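The pretrained-weight block in Example #1 (and the identical blocks below) is the standard partial-loading pattern: keep only the checkpoint entries whose keys exist in the current model, then merge them into the model's own state dict. A minimal standalone sketch of the same idea, with an extra shape check the examples omit (`checkpoint_path` is a placeholder):

import torch

def load_matching_weights(model, checkpoint_path):
    # Load on CPU so the checkpoint's device does not have to match the model's.
    pretrained_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = model.state_dict()
    # Keep only entries present in the model with matching shapes.
    filtered = {k: v for k, v in pretrained_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(filtered)
    model.load_state_dict(model_dict)
    return model

The try/except around `model.module.state_dict()` in the examples handles the case where the model is already wrapped in `nn.DataParallel`, whose parameters live under a `module.` prefix.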
Example #2
def main():
    cudnn.benchmark = False
    cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    logdir = os.path.join(params.log, cur_time)
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    writer = SummaryWriter(log_dir=logdir)

    print("Loading dataset")
    train_dataloader = \
        DataLoader(
            VideoDataset(params.dataset, mode='train', clip_len=params.clip_len, frame_sample_rate=params.frame_sample_rate),
            batch_size=params.batch_size, shuffle=True, num_workers=params.num_workers)

    val_dataloader = \
        DataLoader(
            VideoDataset(params.dataset, mode='validation', clip_len=params.clip_len, frame_sample_rate=params.frame_sample_rate),
            batch_size=params.batch_size, shuffle=False, num_workers=params.num_workers)

    print("load model")
    model = slowfastnet.resnet50(class_num=params.num_classes)

    if params.pretrained is not None:
        pretrained_dict = torch.load(params.pretrained, map_location='cpu')
        try:
            model_dict = model.module.state_dict()
        except AttributeError:
            model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        print("Load pretrain model")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    model = model.cuda(params.gpu[0])
    model = nn.DataParallel(model, device_ids=params.gpu)  # multi-GPU

    criterion = nn.CrossEntropyLoss().cuda(params.gpu[0])
    optimizer = optim.SGD(model.parameters(),
                          lr=params.learning_rate,
                          momentum=params.momentum,
                          weight_decay=params.weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=params.step,
                                          gamma=0.1)

    model_save_dir = os.path.join(params.save_path, cur_time)
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)

    best_loss = 1e6
    best_valid = dict(top1_acc=0., top5_acc=0., epoch=0)
    no_loss_decrease_count = 0

    for epoch in range(params.epoch_num):
        train_top1_acc, train_top5_acc, train_loss = train(
            model, train_dataloader, epoch, criterion, optimizer, writer)

        if (epoch + 1) % params.val_freq == 0:
            val_top1_acc, val_top5_acc, val_loss = validation(
                model, val_dataloader, epoch, criterion, writer)

            if val_top1_acc > best_valid['top1_acc']:
                best_valid['top1_acc'] = val_top1_acc
                best_valid['top5_acc'] = val_top5_acc
                best_valid['epoch'] = epoch

        if train_loss < best_loss:
            best_loss = train_loss
            no_loss_decrease_count = 0
        else:
            no_loss_decrease_count += 1
        if no_loss_decrease_count >= params.patience:
            print(
                f'Early stop on Epoch {epoch} with patience {params.patience}')
            write_exp_log(f'[{epoch}] Early stop')
            break

        scheduler.step()

        if epoch % 1 == 0:
            checkpoint = os.path.join(
                model_save_dir, "clip_len_" + str(params.clip_len) +
                "frame_sample_rate_" + str(params.frame_sample_rate) +
                "_checkpoint_" + str(epoch) + ".pth.tar")
            torch.save(model.module.state_dict(), checkpoint)

    print(
        f'Best Validated model was found on epoch {best_valid["epoch"]}:  Top1 acc: {best_valid["top1_acc"]}  Top5 acc: {best_valid["top5_acc"]}'
    )
    write_exp_log(
        f'Best model found on epoch {best_valid["epoch"]}:  Top1 acc: {best_valid["top1_acc"]}  Top5 acc: {best_valid["top5_acc"]}'
    )

    writer.close()
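Example #2 extends the loop with early stopping on the training loss: the patience counter resets whenever the loss improves and training breaks once it has stagnated for `params.patience` epochs. The same bookkeeping, isolated into a reusable helper (a sketch, not part of the original code):

class EarlyStopper:
    # Stop when the monitored loss has not improved for `patience` epochs.
    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float('inf')
        self.stale = 0

    def step(self, loss):
        if loss < self.best_loss:
            self.best_loss = loss
            self.stale = 0
        else:
            self.stale += 1
        return self.stale >= self.patience

Inside the epoch loop this becomes `if stopper.step(train_loss): break`. Note that the example monitors the training loss; monitoring the validation loss is the more common choice, since training loss usually keeps decreasing.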
Example #3
def bfp_quant(model_name, dataset_dir, num_classes, gpus, mantisa_bit, exp_bit, batch_size=1,
                num_bins=8001, eps=0.0001, num_workers=2, num_examples=10, std=None, mean=None,
                resize=256, crop=224, exp_act=None, bfp_act_chnl=1, bfp_weight_chnl=1, bfp_quant=1,
                target_module_list=None, act_bins_factor=3, fc_bins_factor=4, is_online=0):
    # Setting up gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)


    # Setting up dataload for evaluation
    valdir = os.path.join(dataset_dir, 'val')
    normalize = transforms.Normalize(mean=mean,
                                     std=std)

    
    train_dataloader = DataLoader(VideoDataset(dataset='ucf101', split='train',clip_len=16, model_name=model_name), batch_size=batch_size, shuffle=True, 
                            num_workers=4)
    val_dataloader   = DataLoader(VideoDataset(dataset='ucf101', split='val',  clip_len=16, model_name=model_name), batch_size=num_examples, num_workers=4)
    test_dataloader  = DataLoader(VideoDataset(dataset='ucf101', split='test', clip_len=16, model_name=model_name), batch_size=batch_size, num_workers=4)

    # # for collect intermediate data use
    # collect_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(resize),
    #         transforms.CenterCrop(crop),
    #         transforms.ToTensor(),
    #         normalize,
    #     ])),
    #     batch_size=num_examples, shuffle=False,
    #     num_workers=num_workers, pin_memory=True)
    # # for validate the bfp model use
    # val_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(resize),
    #         transforms.CenterCrop(crop),
    #         transforms.ToTensor(),
    #         normalize,
    #     ])),
    #     batch_size=batch_size, shuffle=False,
    #     num_workers=num_workers, pin_memory=True)

    if (bfp_quant == 1):
        # Loading the model
        model, _ = model_factory_3d.get_network(model_name, pretrained=True)
        # Insert the hook to record the intermediate result
        #target_module_list = [nn.BatchNorm2d,nn.Linear] # Insert hook after BN and FC
        model, intern_outputs = Stat_Collector.insert_hook(model, target_module_list)
        #model = nn.DataParallel(model)
        model.cuda()
        model.eval()
        

        # Collect the intermediate result while running number of examples
        logging.info("Collecting the statistics while running image examples....")
        images_statistc = torch.empty((1))
        with torch.no_grad():
            for i_batch, (images, lables) in enumerate(val_dataloader):
                images = images.cuda()
                outputs = model(images)
                #print(lables)
                _, predicted = torch.max(outputs.data, 1)
                predicted = predicted.cpu() # needs to verify if this line can be deleted  
                # Collect the input data
                image_shape = images.shape
                images_statistc = torch.reshape(images, 
                                        (image_shape[0], image_shape[1], image_shape[2], image_shape[3]*image_shape[4]))
                break
        
        # Determining the optimal exponent of activation and
        # Constructing the distribution for tensorboardX visualization
        logging.info("Determining the optimal exponent by minimizing the KL divergence....")
        start = time.time()
        opt_exp_act_list = []
        max_exp_act_list = []
        # For original input
        opt_exp, max_exp = Utils.find_exp_act_3d(images_statistc, mantisa_bit, exp_bit, group=3, eps=eps, bins_factor=act_bins_factor)
        opt_exp_act_list.append(opt_exp)
        max_exp_act_list.append(max_exp)
        sc_layer = []
        ds_sc_layer = []
        if model_name == "r3d_18":
            sc_layer = [2, 4, 9, 14, 19]
            ds_sc_layer = [6, 7, 11, 12, 16, 17]
        else:
            sc_layer = [2, 4, 6, 11, 13, 15, 20, 22, 24, 26, 28, 33, 35]
            ds_sc_layer = [8, 9, 17, 18, 30, 31]
        for i, intern_output in enumerate(intern_outputs):
            # Determine the optimal exponent by minimizing the KL divergence in a channel-wise manner
            print ("i-th", i, "  shape:", intern_output.out_features.shape, " name:", intern_output.m)
            if (isinstance(intern_output.m, nn.Conv3d) or isinstance(intern_output.m, nn.BatchNorm3d)):
                if i in sc_layer:  # sc_layer/ds_sc_layer above already encode the model choice
                    intern_features1 = intern_output.out_features
                    intern_features2 = intern_outputs[i-2].out_features
                    intern_shape = intern_features1.shape  # single-feature shape; batch dim doubles after cat
                    intern_features = torch.cat((intern_features1, intern_features2), 0)
                    intern_features = torch.reshape(intern_features,
                                    (2*intern_shape[0], intern_shape[1], intern_shape[2], intern_shape[3]*intern_shape[4]))
                    opt_exp, max_exp = Utils.find_exp_act_3d(intern_features, mantisa_bit, exp_bit, 
                                                    group = bfp_act_chnl, eps=eps, bins_factor=act_bins_factor)
                    print ("i-th", i, "  length:", len(opt_exp))
                    opt_exp_act_list.append(opt_exp) ##changed
                    max_exp_act_list.append(max_exp)
                elif ((model_name=="r3d") and (i in ds_sc_layer)):
                    intern_features1 = intern_output.out_features
                    if ((i+1) in ds_sc_layer):
                        intern_features2 = intern_outputs[i+1].out_features
                    else:
                        continue # Use the same exp as previous layer
                    intern_features = torch.cat((intern_features1, intern_features2), 0)
                    intern_features = torch.reshape(intern_features,
                                    (2*intern_shape[0], intern_shape[1], intern_shape[2], intern_shape[3]*intern_shape[4]))
                    opt_exp, max_exp = Utils.find_exp_act_3d(intern_features, mantisa_bit, exp_bit, 
                                                    group = bfp_act_chnl, eps=eps, bins_factor=act_bins_factor)
                    print ("i-th", i, "  length:", len(opt_exp))
                    opt_exp_act_list.append(opt_exp) ##changed
                    max_exp_act_list.append(max_exp)
                else:
                    intern_shape = intern_output.out_features.shape
                    print (intern_shape)
                    intern_features = torch.reshape(intern_output.out_features,
                                    (intern_shape[0], intern_shape[1], intern_shape[2], intern_shape[3]*intern_shape[4]))
                    opt_exp, max_exp = Utils.find_exp_act_3d(intern_features, mantisa_bit, exp_bit, 
                                                    group = bfp_act_chnl, eps=eps, bins_factor=act_bins_factor)
                    print ("i-th", i, "  length:", len(opt_exp))
                    opt_exp_act_list.append(opt_exp) ##changed
                    max_exp_act_list.append(max_exp)
            elif (isinstance(intern_output.m, nn.Linear)):
                intern_shape = intern_output.out_features.shape
                opt_exp, max_exp = Utils.find_exp_fc(intern_output.out_features, mantisa_bit, exp_bit, block_size=intern_shape[1], eps=eps, bins_factor=fc_bins_factor)
                #print ("shape of fc exponent:", np.shape(opt_exp))
                opt_exp_act_list.append(opt_exp)  # keep the KL-optimal exponent for the fc layer
                max_exp_act_list.append(max_exp)
            else:
                pass
                '''
                intern_shape = intern_output.in_features[0].shape
                intern_features = torch.reshape(intern_output.in_features[0], 
                                    (intern_shape[0], intern_shape[1], intern_shape[2]*intern_shape[3]))
                opt_exp, max_exp = Utils.find_exp_act(intern_features, mantisa_bit, exp_bit, 
                                                    group = bfp_act_chnl, eps=eps, bins_factor=act_bins_factor)
                opt_exp_act_list.append(opt_exp)
                max_exp_act_list.append(max_exp)
                '''
                
            #logging.info("The internal shape: %s" % ((str)(intern_output.out_features.shape)))
        end = time.time()
        logging.info("It took %f second to determine the optimal shared exponent for each block." % ((float)(end-start)))
        logging.info("The shape of collect exponents: %s" % ((str)(np.shape(opt_exp_act_list))))

        # Building a BFP model by inserting BFPAct and BFPWeight based on opt_exp_act_list
        torch.cuda.empty_cache() 
        if (exp_act=='kl'):
            exp_act_list = opt_exp_act_list
        else:
            exp_act_list = max_exp_act_list
    else:
        exp_act_list = None
    bfp_model, weight_exp_list = model_factory_3d.get_network(model_name, pretrained=True, bfp=(bfp_quant==1), group=bfp_weight_chnl, mantisa_bit=mantisa_bit, 
                exp_bit=exp_bit, opt_exp_act_list=exp_act_list, is_online=is_online, exp_act=exp_act)

    dynamic_bfp_model, dynamic_weight_exp_list = model_factory_3d.get_network(model_name, pretrained=True, bfp=(bfp_quant==1), group=bfp_weight_chnl, mantisa_bit=mantisa_bit, 
                exp_bit=exp_bit, opt_exp_act_list=exp_act_list, is_online=1, exp_act=exp_act)

    #torch.cuda.empty_cache() 
    confusion_matrix = torch.zeros(num_classes, num_classes)
    dynamic_confusion_matrix = torch.zeros(num_classes, num_classes)
    logging.info("Evaluation Block Floating Point quantization....")
    correct = 0
    total = 0
    dynamic_correct = 0
    dynamic_total = 0
    if ((model_name != "br_mobilenetv2") or (model_name != "mobilenetv2")):
        bfp_model = nn.DataParallel(bfp_model)
    bfp_model.cuda()
    bfp_model.eval()
    if ((model_name != "br_mobilenetv2") or (model_name != "mobilenetv2")):
        dynamic_bfp_model = nn.DataParallel(dynamic_bfp_model)
    dynamic_bfp_model.cuda()
    dynamic_bfp_model.eval()
    with torch.no_grad():
        for i_batch, (images, lables) in enumerate(test_dataloader):
            images = images.cuda()
            outputs = bfp_model(images)
            probs = nn.Softmax(dim=1)(outputs)
            _, predicted = torch.max(probs, 1)
            predicted = predicted.cpu()
            dynamic_outputs = dynamic_bfp_model(images)
            dynamic_probs = nn.Softmax(dim=1)(dynamic_outputs)
            _,dynamic_predicted = torch.max(dynamic_probs, 1)
            dynamic_predicted = dynamic_predicted.cpu()
            for t, p in zip(lables.view(-1), predicted.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            for t, p in zip(lables.view(-1), dynamic_predicted.view(-1)):
                dynamic_confusion_matrix[t.long(), p.long()] += 1
            total += lables.size(0)
            correct += (predicted == lables).sum().item()
            logging.info("Current images: %d" % (total))
            #if (total > 2000):
            #    break
    logging.info("Total: %d, Accuracy: %f " % (total, float(correct / total)))
    logging.info("Floating conv weight and fc(act and weight), act bins_factor is %d,fc bins_factor is %d, exp_opt for act is %s, act group is %d"%(act_bins_factor, fc_bins_factor, exp_act, bfp_act_chnl))    
    print ("Per class accuracy:", (confusion_matrix.diag()/confusion_matrix.sum(1)))
    print ("Per class dynamic accuracy:", (dynamic_confusion_matrix.diag()/dynamic_confusion_matrix.sum(1)))
    torch.save(confusion_matrix, "static_bfp_r3d.pt")
    torch.save(dynamic_confusion_matrix, "dynamic_bfp_r3d.pt")
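Example #3 relies on the repository's `Utils.find_exp_act_3d` / `Utils.find_exp_fc` to pick, for each block of values, a shared exponent that minimizes the KL divergence between the original and the quantized distribution (falling back to the max exponent when `exp_act != 'kl'`). The sketch below only illustrates the general idea for a single block; the exact BFP format and search range used by the repository are assumptions here:

import torch

def bfp_quantize(x, shared_exp, mantissa_bits):
    # All values in the block share one exponent; only the mantissa is rounded.
    scale = 2.0 ** (shared_exp - (mantissa_bits - 1))
    qmax = 2 ** (mantissa_bits - 1) - 1
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale

def find_shared_exp_kl(x, mantissa_bits, search=4, bins=2048, eps=1e-4):
    # The max exponent never clips, so search a few candidates at and below it.
    max_exp = int(torch.ceil(torch.log2(x.abs().max() + eps)))
    lo, hi = float(x.min()), float(x.max())
    p = torch.histc(x, bins=bins, min=lo, max=hi) + eps
    p = p / p.sum()
    best_exp, best_kl = max_exp, float('inf')
    for exp in range(max_exp - search, max_exp + 1):
        q = torch.histc(bfp_quantize(x, exp, mantissa_bits),
                        bins=bins, min=lo, max=hi) + eps
        q = q / q.sum()
        kl = torch.sum(p * torch.log(p / q)).item()
        if kl < best_kl:
            best_exp, best_kl = exp, kl
    return best_exp

Lowering the exponent below the max trades clipping of outliers for finer resolution of the bulk of the distribution, which is exactly the trade-off the KL criterion arbitrates.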
Example #4
def bfp_quant(model_name,
              dataset_dir,
              num_classes,
              gpus,
              mantisa_bit,
              exp_bit,
              batch_size=1,
              num_bins=8001,
              eps=0.0001,
              num_workers=2,
              num_examples=10,
              std=None,
              mean=None,
              resize=256,
              crop=224,
              exp_act=None,
              bfp_act_chnl=1,
              bfp_weight_chnl=1,
              bfp_quant=1,
              target_module_list=None,
              act_bins_factor=3,
              fc_bins_factor=4,
              is_online=0):
    # Setting up gpu environment
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)

    # Setting up dataload for evaluation
    valdir = os.path.join(dataset_dir, 'val')
    normalize = transforms.Normalize(mean=mean, std=std)

    train_dataloader = DataLoader(VideoDataset(dataset='ucf101',
                                               split='train',
                                               clip_len=16,
                                               model_name=model_name),
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=4)
    val_dataloader = DataLoader(VideoDataset(dataset='ucf101',
                                             split='val',
                                             clip_len=16,
                                             model_name=model_name),
                                batch_size=num_examples,
                                num_workers=4)
    test_dataloader = DataLoader(VideoDataset(dataset='ucf101',
                                              split='test',
                                              clip_len=16,
                                              model_name=model_name),
                                 batch_size=batch_size,
                                 num_workers=4)

    # # for collect intermediate data use
    # collect_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(resize),
    #         transforms.CenterCrop(crop),
    #         transforms.ToTensor(),
    #         normalize,
    #     ])),
    #     batch_size=num_examples, shuffle=False,
    #     num_workers=num_workers, pin_memory=True)
    # # for validate the bfp model use
    # val_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(resize),
    #         transforms.CenterCrop(crop),
    #         transforms.ToTensor(),
    #         normalize,
    #     ])),
    #     batch_size=batch_size, shuffle=False,
    #     num_workers=num_workers, pin_memory=True)

    exp_act_list = None
    lq_model, weight_exp_list = model_factory_3d.get_network(
        model_name,
        pretrained=True,
        bfp=(bfp_quant == 1),
        group=bfp_weight_chnl,
        mantisa_bit=mantisa_bit,
        exp_bit=exp_bit,
        opt_exp_act_list=exp_act_list,
        is_online=is_online,
        exp_act=exp_act)

    # Post-training static quantization: the model stays in eval mode, observers
    # are inserted with prepare(), calibrated with one forward pass, then converted.
    lq_model.eval()
    lq_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    lq_model.fuse_model()
    lq_model_prepared = torch.quantization.prepare(lq_model)
    #torch.cuda.empty_cache()

    inpt_fp32 = torch.randn(1, 3, 16, 112, 112)
    lq_model_prepared(inpt_fp32)
    lq_model_8bit = torch.quantization.convert(lq_model_prepared)

    logging.info("Evaluating linear-quant model....")
    correct = 0
    total = 0

    # fbgemm quantized kernels run on the CPU, so the model and inputs stay there.
    lq_model_8bit.eval()
    with torch.no_grad():
        for i_batch, (images, lables) in enumerate(test_dataloader):
            outputs = lq_model_8bit(images)
            #outputs = model(images)
            probs = nn.Softmax(dim=1)(outputs)
            _, predicted = torch.max(probs, 1)
            predicted = predicted.cpu()
            total += lables.size(0)
            correct += (predicted == lables).sum().item()
            logging.info("Current images: %d" % (total))
            #if (total > 2000):
            #    break
    logging.info("Total: %d, Accuracy: %f " % (total, float(correct / total)))
    logging.info(
        "Floating conv weight and fc (act and weight), act bins_factor is %d, fc bins_factor is %d, exp_opt for act is %s, act group is %d"
        % (act_bins_factor, fc_bins_factor, exp_act, bfp_act_chnl))
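Example #4 performs a single calibration pass and then converts, which is PyTorch's eager-mode post-training static quantization flow: set a qconfig, prepare() to insert observers, run calibration data through, then convert(); the resulting fbgemm kernels are CPU-only. A self-contained version of that flow on a toy module (TinyNet is made up for illustration):

import torch
import torch.nn as nn
from torch import quantization

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = quantization.QuantStub()      # fp32 -> int8 boundary
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = quantization.DeQuantStub()  # int8 -> fp32 boundary

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

model = TinyNet().eval()
model.qconfig = quantization.get_default_qconfig('fbgemm')
prepared = quantization.prepare(model)       # insert observers
prepared(torch.randn(4, 3, 32, 32))          # calibration pass
model_int8 = quantization.convert(prepared)  # swap in quantized modules
out = model_int8(torch.randn(1, 3, 32, 32))  # runs on CPU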
Example #5
File: train.py  Project: Zhicaiwww/SFnet
def main():
    cudnn.benchmark = False
    cur_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    logdir = os.path.join(params['log'], cur_time)
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    writer = SummaryWriter(log_dir=logdir)

    if params['train_val_Shuffle']:
        with open(params['tagg_info_txt'], 'r', encoding='utf-8') as f:
            lines = f.readlines()
            random.shuffle(lines)
            train_nums = 0.9 * len(lines)
            out_file_train = open(params['filename_imglist_train'],
                                  'w',
                                  encoding='utf-8')
            out_file_val = open(params['filename_imglist_val'],
                                'w',
                                encoding='utf-8')
            try:
                for index, line in tqdm.tqdm(enumerate(lines),
                                             total=len(lines)):
                    out_file = out_file_train if index < train_nums else out_file_val

                    line = line.strip()
                    if len(line.split('\t')) != 2:
                        continue  # skip malformed lines instead of aborting the whole split
                    path, label_strings = line.split('\t')
                    vid = os.path.basename(path)
                    out_file.write(vid + '\t' + label_strings + '\n')
            except Exception as e:
                print(e)
            finally:
                # Close (and flush) the split files before the DataLoaders read them.
                out_file_train.close()
                out_file_val.close()
    else:
        if os.path.exists(params['filename_imglist_train']) and os.path.exists(
                params['filename_imglist_val']):
            print("loading existing train and val splits!!")
        else:
            print("no existing train/val splits found; set 'train_val_Shuffle' to create them")
            return

    print("Loading dataset")

    train_dataloader = \
        DataLoader(
            VideoDataset(params['dataset'], params['filename_imglist_train'], params['label_id_file'], mode='train', clip_len=params['clip_len'], frame_sample_rate=params['frame_sample_rate']),
            batch_size=params['batch_size'], shuffle=True, num_workers=params['num_workers'])

    val_dataloader = \
        DataLoader(
            VideoDataset(params['dataset'], params['filename_imglist_val'],params['label_id_file'], mode='validation', clip_len=params['clip_len'], frame_sample_rate=params['frame_sample_rate']),
            batch_size=params['batch_size'], shuffle=False, num_workers=params['num_workers'])

    print("load model")
    model = slowfastnet.resnet50(class_num=params['num_classes'])

    if params['pretrained'] is not None:
        pretrained_dict = torch.load(params['pretrained'], map_location='cpu')
        try:
            model_dict = model.module.state_dict()
        except AttributeError:
            model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        print("load pretrain model")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    model = model.cuda(params['gpu'][0])
    model = nn.DataParallel(model, device_ids=params['gpu'])  # multi-GPU

    criterion = loss.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=params['learning_rate'],
                          momentum=params['momentum'],
                          weight_decay=params['weight_decay'])
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=params['step'],
                                          gamma=0.1)

    model_save_dir = os.path.join(params['save_path'], cur_time)
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)

    gamma = params['gamma']
    best_gap = 0.0  # best validation GAP seen so far
    for epoch in range(params['epoch_num']):
        if params['train_rule'] == 'DRW':
            betas = [0, 0.99]
            # Re-weighting kicks in after epoch 30; clamp the index to stay in range.
            idx = min(epoch // 30, len(betas) - 1)
            cls_num_list = [100] * params['num_classes']
            effective_num = 1.0 - np.power(betas[idx], cls_num_list)
            per_cls_weights = (1.0 - betas[idx]) / np.array(effective_num)
            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
                cls_num_list)
            per_cls_weights = torch.FloatTensor(per_cls_weights)
        else:
            per_cls_weights = torch.ones(1)
        print(f"per_cls_weights{per_cls_weights}")
        train(train_dataloader, model, criterion, optimizer, epoch,
              per_cls_weights, gamma, writer)
        if epoch % 2 == 0:
            gap = validate2(val_dataloader,
                            model,
                            criterion,
                            epoch,
                            tfwriter=writer)

            is_best = gap > best_gap
            best_gap = max(gap, best_gap)
            writer.add_scalar('acc/val_gap_best', best_gap, epoch)
            output_best = 'Best val gap: %.3f\n' % (best_gap)
            print(output_best)
            if is_best and gap > 0.75:
                checkpoint = os.path.join(
                    model_save_dir, "clip_len_" + str(params['clip_len']) +
                    "frame_sample_rate_" + str(params['frame_sample_rate']) +
                    "_checkpoint_" + str(epoch) + "_gap" + str(gap) +
                    ".pth.tar")
                torch.save(model.module.state_dict(), checkpoint)
        scheduler.step()
    writer.close()
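The DRW branch in Example #5 computes class weights from the "effective number of samples" of Cui et al. (2019): E_n = (1 - beta^n) / (1 - beta), with each class weighted by 1/E_n and the weights normalized to sum to the number of classes. A small standalone check of the formula (the class counts here are illustrative; the example itself uses a constant 100 per class):

import numpy as np

def drw_class_weights(cls_num_list, beta):
    # Effective number of samples per class: (1 - beta^n) / (1 - beta).
    effective_num = 1.0 - np.power(beta, cls_num_list)
    weights = (1.0 - beta) / effective_num
    # Normalize so the weights sum to the number of classes.
    return weights / weights.sum() * len(cls_num_list)

print(drw_class_weights([1000, 100, 10], beta=0.99))  # rare classes get larger weights

Because the example feeds a constant `cls_num_list` of 100 per class, every class ends up with weight 1.0 even after beta switches to 0.99; the machinery only matters once real per-class counts are substituted.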