Example #1
def train_model(sym_net,
                model_prefix,
                dataset,
                input_conf,
                clip_length=16,
                train_frame_interval=2,
                resume_epoch=-1,
                batch_size=4,
                save_frequency=1,
                lr_base=0.01,
                lr_factor=0.1,
                lr_steps=[400000, 800000],
                end_epoch=1000,
                distributed=False,
                fine_tune=False,
                **kwargs):

    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
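    # the offsets below make the data-loading seed depend on the resume epoch,
    # so a resumed run does not replay the exact same shuffling order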
    iter_seed = torch.initial_seed() + 100 + max(0, resume_epoch) * 100
    train_iter = iter_fac.creat(name=dataset,
                                batch_size=batch_size,
                                clip_length=clip_length,
                                train_interval=train_frame_interval,
                                mean=input_conf['mean'],
                                std=input_conf['std'],
                                seed=iter_seed)
    # wrapper (dynamic model)
    net = model(
        net=sym_net,
        criterion=nn.CrossEntropyLoss().cuda(),
        model_prefix=model_prefix,
        step_callback_freq=50,
        save_checkpoint_freq=save_frequency,
        opt_batch_size=batch_size,
    )
    net.net.cuda()

    # config optimization
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    for name, param in net.net.named_parameters():
        if fine_tune:
            if ('classifier' in name) or ('fc' in name):
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        else:
            param_new_layers.append(param)

    if name_base_layers:
        out = "[\'" + '\', \''.join(name_base_layers) + "\']"
        logging.info(
            "Optimizer:: >> recuding the learning rate of {} params: {}".
            format(
                len(name_base_layers),
                out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    net.net = torch.nn.DataParallel(net.net).cuda()

    optimizer = torch.optim.SGD([{
        'params': param_base_layers,
        'lr_mult': 0.2
    }, {
        'params': param_new_layers,
        'lr_mult': 1.0
    }],
                                lr=lr_base,
                                momentum=0.9,
                                weight_decay=0.0001,
                                nesterov=True)

    # load params from pretrained 3d network
    if resume_epoch > 0:
        logging.info("Initializer:: resuming model from previous training")

    # resume training: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * train_iter.__len__()

    # set learning rate scheduler
    num_worker = (dist.get_world_size()
                  if torch.distributed.is_initialized() else 1)
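    # lr_steps appear to be counted in samples; dividing by the effective batch
    # size (batch_size * num_worker) converts them to optimizer-update counts
    # (with the defaults, 400000 / 4 = 100000 updates before the first decay)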
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)
    # define evaluation metric
    metrics = metric.MetricList(
        metric.Loss(name="loss-ce"),
        metric.Accuracy(name="top1", topk=1),
        metric.Accuracy(name="top5", topk=5),
    )

    net.fit(
        train_iter=train_iter,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        metrics=metrics,
        epoch_start=epoch_start,
        epoch_end=end_epoch,
    )
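
Note: Example #1 (and the examples that follow) relies on a MultiFactorScheduler that is not defined in this listing, and it attaches an extra 'lr_mult' key to each optimizer parameter group. The sketch below shows one way such a multi-step scheduler could scale the base learning rate and how the per-group multiplier could be applied; the interface is inferred from how the class is called here, not the project's actual implementation.

# Sketch only: assumed interface of a multi-step LR scheduler with per-group 'lr_mult'.
class MultiFactorScheduler(object):
    def __init__(self, base_lr, steps, factor=0.1, step_counter=0):
        self.steps = sorted(steps)
        self.factor = factor
        self.counter = step_counter
        # account for decay steps already passed when resuming mid-training
        passed = sum(1 for s in self.steps if s <= step_counter)
        self.lr = base_lr * (factor ** passed)

    def update(self):
        # call once per optimizer update; returns the learning rate to use next
        self.counter += 1
        if self.counter in self.steps:
            self.lr *= self.factor
        return self.lr


def apply_lr(optimizer, lr):
    # torch optimizers keep unknown keys such as 'lr_mult' in each param group,
    # so the scheduled lr can be scaled per group before optimizer.step()
    for group in optimizer.param_groups:
        group['lr'] = lr * group.get('lr_mult', 1.0)

A training loop would then call apply_lr(optimizer, lr_scheduler.update()) before each optimizer.step().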
Example #2
def train_model(net_name,
                sym_net,
                model_prefix,
                dataset,
                input_conf,
                modality='rgb',
                split=1,
                clip_length=16,
                train_frame_interval=2,
                val_frame_interval=2,
                resume_epoch=-1,
                batch_size=4,
                save_frequency=1,
                lr_base=0.01,
                lr_base2=0.01,
                lr_d=None,
                lr_factor=0.1,
                lr_steps=[400000, 800000],
                end_epoch=1000,
                distributed=False,
                pretrained_3d=None,
                fine_tune=False,
                iter_size=1,
                optim='sgd',
                accumulate=True,
                ds_factor=16,
                epoch_thre=1,
                score_dir=None,
                mv_minmaxnorm=False,
                mv_loadimg=False,
                detach=False,
                adv=0,
                new_classifier=False,
                **kwargs):

    assert torch.cuda.is_available(), "Currently, we only support CUDA version"
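    # work around "too many open files" errors from many DataLoader workers:
    # share tensors via the filesystem instead of file descriptors, and raise
    # the soft open-file limit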
    torch.multiprocessing.set_sharing_strategy('file_system')
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
    # data iterator
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100

    train_iter, eval_iter = iterator_factory.creat(
        name=dataset,
        batch_size=batch_size,
        clip_length=clip_length,
        train_interval=train_frame_interval,
        val_interval=val_frame_interval,
        mean=input_conf['mean'],
        std=input_conf['std'],
        seed=iter_seed,
        modality=modality,
        split=split,
        net_name=net_name,
        accumulate=accumulate,
        ds_factor=ds_factor,
        mv_minmaxnorm=mv_minmaxnorm,
        mv_loadimg=mv_loadimg)
    # define an instance of the model wrapper class
    net = model(
        net=sym_net,
        criterion=torch.nn.CrossEntropyLoss().cuda(),
        model_prefix=model_prefix,
        step_callback_freq=50,
        save_checkpoint_freq=save_frequency,
        opt_batch_size=batch_size,  # optional
        criterion2=torch.nn.MSELoss().cuda()
        if modality == 'flow+mp4' else None,
        criterion3=torch.nn.CrossEntropyLoss().cuda() if adv > 0. else None,
        adv=adv,
    )
    net.net.cuda()
    print(torch.cuda.current_device(), torch.cuda.device_count())
    # config optimization
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    params_gf = []
    params_d = []
    for name, param in net.net.named_parameters():
        if modality == 'flow+mp4':
            if name.startswith('gen_flow_model'):
                params_gf.append(param)
            elif name.startswith('discriminator'):
                params_d.append(param)
            else:
                if (name.startswith('conv3d_0c_1x1')
                        or name.startswith('classifier')):
                    #if name.startswith('classifier'):
                    param_new_layers.append(param)
                else:
                    param_base_layers.append(param)
                    name_base_layers.append(name)
            #else:
            #    #print(name)
            #    param_new_layers.append(param)
        else:
            if fine_tune:
                if name.startswith('classifier') or name.startswith(
                        'conv3d_0c_1x1'):
                    #if name.startswith('classifier'):
                    param_new_layers.append(param)
                else:
                    param_base_layers.append(param)
                    name_base_layers.append(name)
            else:
                param_new_layers.append(param)
    if modality == 'flow+mp4':
        if fine_tune:
            lr_mul = 0.2
        else:
            lr_mul = 0.5
    else:
        lr_mul = 0.2
    #print(params_d)
    if name_base_layers:
        out = "[\'" + '\', \''.join(name_base_layers) + "\']"
        logging.info(
            "Optimizer:: >> recuding the learning rate of {} params: {} by factor {}"
            .format(
                len(name_base_layers),
                out if len(out) < 300 else out[0:150] + " ... " + out[-150:],
                lr_mul))
    if net_name == 'I3D':
        weight_decay = 0.0001
    else:
        raise ValueError('UNKNOWN net_name', net_name)
    logging.info("Train_Model:: weight_decay: `{}'".format(weight_decay))
    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    if optim == 'adam':
        optimizer = torch.optim.Adam([{
            'params': param_base_layers,
            'lr_mult': lr_mul
        }, {
            'params': param_new_layers,
            'lr_mult': 1.0
        }],
                                     lr=lr_base,
                                     weight_decay=weight_decay)
        optimizer_2 = torch.optim.Adam([{
            'params': param_base_layers,
            'lr_mult': lr_mul
        }, {
            'params': param_new_layers,
            'lr_mult': 1.0
        }],
                                       lr=lr_base2,
                                       weight_decay=weight_decay)
    else:
        optimizer = torch.optim.SGD([{
            'params': param_base_layers,
            'lr_mult': lr_mul
        }, {
            'params': param_new_layers,
            'lr_mult': 1.0
        }],
                                    lr=lr_base,
                                    momentum=0.9,
                                    weight_decay=weight_decay,
                                    nesterov=True)
        optimizer_2 = torch.optim.SGD([{
            'params': param_base_layers,
            'lr_mult': lr_mul
        }, {
            'params': param_new_layers,
            'lr_mult': 1.0
        }],
                                      lr=lr_base2,
                                      momentum=0.9,
                                      weight_decay=weight_decay,
                                      nesterov=True)
    if adv > 0.:
        optimizer_3 = torch.optim.Adam(params_d,
                                       lr=lr_base,
                                       weight_decay=weight_decay,
                                       eps=0.001)
    else:
        optimizer_3 = None
    if modality == 'flow+mp4':
        if optim == 'adam':
            optimizer_mse = torch.optim.Adam(params_gf,
                                             lr=lr_base,
                                             weight_decay=weight_decay,
                                             eps=1e-08)
            optimizer_mse_2 = torch.optim.Adam(params_gf,
                                               lr=lr_base2,
                                               weight_decay=weight_decay,
                                               eps=0.001)
        else:
            optimizer_mse = torch.optim.SGD(params_gf,
                                            lr=lr_base,
                                            momentum=0.9,
                                            weight_decay=weight_decay,
                                            nesterov=True)
            optimizer_mse_2 = torch.optim.SGD(params_gf,
                                              lr=lr_base2,
                                              momentum=0.9,
                                              weight_decay=weight_decay,
                                              nesterov=True)
    else:
        optimizer_mse = None
        optimizer_mse_2 = None
    # load params from pretrained 3d network
    if pretrained_3d and not pretrained_3d == 'False':
        if resume_epoch < 0:
            assert os.path.exists(pretrained_3d), "cannot locate: `{}'".format(
                pretrained_3d)
            logging.info(
                "Initializer:: loading model states from: `{}'".format(
                    pretrained_3d))
            if net_name == 'I3D':
                checkpoint = torch.load(pretrained_3d)
                keys = list(checkpoint.keys())
                state_dict = {}
                for name in keys:
                    state_dict['module.' + name] = checkpoint[name]
                del checkpoint
                net.load_state(state_dict, strict=False)
                if new_classifier:
                    checkpoint = torch.load(
                        './network/pretrained/model_flow.pth')
                    keys = list(checkpoint.keys())
                    state_dict = {}
                    for name in keys:
                        state_dict['module.' + name] = checkpoint[name]
                    del checkpoint
                    net.load_state(state_dict, strict=False)
            else:
                checkpoint = torch.load(pretrained_3d)
                net.load_state(checkpoint['state_dict'], strict=False)
        else:
            logging.info(
                "Initializer:: skip loading model states from: `{}', "
                "since it's going to be overwritten by the resumed model".format(
                    pretrained_3d))

    # resume training: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch,
                            optimizer=optimizer,
                            optimizer_mse=optimizer_mse)
        epoch_start = resume_epoch
        step_counter = epoch_start * train_iter.__len__()

    # set learning rate scheduler
    num_worker = (dist.get_world_size()
                  if torch.distributed.is_initialized() else 1)
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)
    if modality == 'flow+mp4':
        lr_scheduler2 = MultiFactorScheduler(
            base_lr=lr_base2,
            steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
            factor=lr_factor,
            step_counter=step_counter)
        if lr_d is None:
            # fall back to the main base learning rate for the discriminator scheduler
            lr_d = lr_base
        lr_scheduler3 = MultiFactorScheduler(
            base_lr=lr_d,
            steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
            factor=lr_factor,
            step_counter=step_counter)
    else:
        lr_scheduler2 = None
        lr_scheduler3 = None
    # define evaluation metric
    metrics_D = None
    if modality == 'flow+mp4':
        metrics = metric.MetricList(
            metric.Loss(name="loss-ce"),
            metric.Loss(name="loss-mse"),
            metric.Accuracy(name="top1", topk=1),
            metric.Accuracy(name="top5", topk=5),
        )
        if adv > 0:
            metrics_D = metric.MetricList(metric.Loss(name="classi_D"),
                                          metric.Loss(name="adv_D"))

    else:
        metrics = metric.MetricList(
            metric.Loss(name="loss-ce"),
            metric.Accuracy(name="top1", topk=1),
            metric.Accuracy(name="top5", topk=5),
        )
    # enable cudnn tune
    cudnn.benchmark = True
    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch,
            iter_size=iter_size,
            optimizer_mse=optimizer_mse,
            optimizer_2=optimizer_2,
            optimizer_3=optimizer_3,
            optimizer_mse_2=optimizer_mse_2,
            lr_scheduler2=lr_scheduler2,
            lr_scheduler3=lr_scheduler3,
            metrics_D=metrics_D,
            epoch_thre=epoch_thre,
            score_dir=score_dir,
            detach=detach)
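
Note: when net_name == 'I3D', Example #2 prepends 'module.' to every checkpoint key before loading, because the weights were saved from an unwrapped network but are loaded into a (Distributed)DataParallel wrapper whose parameters live under net.module. A small stand-alone sketch of that remapping (the helper name is made up for illustration; net.load_state is the wrapper's own method):

import torch

def to_data_parallel_state_dict(checkpoint_path):
    """Load a plain checkpoint and rename its keys for a DataParallel-wrapped model."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # some checkpoints store the weights under a 'state_dict' entry
    state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
    return {'module.' + name: tensor for name, tensor in state_dict.items()}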
Example #3
def train_model(sym_net, model_prefix, dataset, input_conf,
                clip_length=16, train_frame_interval=2, val_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, 
                pretrained_3d=None, fine_tune=False,
                load_from_frames=False, use_flow=False, triplet_loss=False,
                **kwargs):

    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(name=dataset,
                                                   batch_size=batch_size,
                                                   clip_length=clip_length,
                                                   train_interval=train_frame_interval,
                                                   val_interval=val_frame_interval,
                                                   mean=input_conf['mean'],
                                                   std=input_conf['std'],
                                                   seed=iter_seed,
                                                   load_from_frames=load_from_frames,
                                                   use_flow=use_flow)
    # wrapper (dynamic model)
    if use_flow:
        class LogNLLLoss(torch.nn.Module):
            def __init__(self):
                super(LogNLLLoss, self).__init__()
                self.loss = torch.nn.NLLLoss()

            def forward(self, output, target):
                output = torch.log(output)
                loss = self.loss(output, target)
                return loss
        # criterion = LogNLLLoss().cuda()
        criterion = torch.nn.CrossEntropyLoss().cuda()
    elif triplet_loss:
        logging.info("Using triplet loss")
        criterion=torch.nn.MarginRankingLoss().cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss().cuda()
    net = model(net=sym_net,
                criterion=criterion,
                triplet_loss=triplet_loss,
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size, # optional
                )
    net.net.cuda()

    # config optimization
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    for name, param in net.net.named_parameters():
        if fine_tune:
            # if name.startswith('classifier'):
            if 'classifier' in name or 'fc' in name:
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        else:
            param_new_layers.append(param)

    if name_base_layers:
        out = "[\'" + '\', \''.join(name_base_layers) + "\']"
        logging.info("Optimizer:: >> recuding the learning rate of {} params: {}".format(len(name_base_layers),
                     out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    optimizer = torch.optim.SGD([{'params': param_base_layers, 'lr_mult': 0.2},
                                 {'params': param_new_layers, 'lr_mult': 1.0}],
                                lr=lr_base,
                                momentum=0.9,
                                weight_decay=0.0001,
                                nesterov=True)

    # load params from pretrained 3d network
    if pretrained_3d:
        if resume_epoch < 0:
            if os.path.exists(pretrained_3d):
                # assert os.path.exists(pretrained_3d), "cannot locate: '{}'".format(pretrained_3d)
                logging.info("Initializer:: loading model states from: `{}'".format(pretrained_3d))
                checkpoint = torch.load(pretrained_3d)
                net.load_state(checkpoint['state_dict'], mode='ada')
            else:
                logging.warning("cannot locate: '{}'".format(pretrained_3d))
        else:
            logging.info("Initializer:: skip loading model states from: `{}'"
                + ", since it's going to be overwrited by the resumed model".format(pretrained_3d))

    # resume training: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * train_iter.__len__()

    # set learning rate scheduler
    num_worker = 1
    lr_scheduler = MultiFactorScheduler(base_lr=lr_base,
                                        steps=[int(x/(batch_size*num_worker)) for x in lr_steps],
                                        factor=lr_factor,
                                        step_counter=step_counter)
    # define evaluation metric
    if triplet_loss:
        metrics = metric.MetricList(metric.Loss(name="loss-triplet"),
                                    metric.TripletAccuracy(name="acc"), )
    else:
        metrics = metric.MetricList(metric.Loss(name="loss-ce"),
                                    metric.Accuracy(name="top1", topk=1),
                                    metric.Accuracy(name="top5", topk=5), )
    # enable cudnn tune
    cudnn.benchmark = True

    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch,)
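
Note: when triplet_loss is set, Example #3 only swaps the criterion to torch.nn.MarginRankingLoss; how the wrapper feeds it is not shown here. A minimal hedged sketch of the usual triplet-style usage (the distance tensors and the margin value are illustrative):

import torch

margin_loss = torch.nn.MarginRankingLoss(margin=1.0)

# distances from an anchor embedding to a positive and a negative sample
dist_pos = torch.tensor([0.3, 0.5])
dist_neg = torch.tensor([0.9, 0.4])

# target = 1 asks for dist_neg > dist_pos, i.e. max(0, margin - (dist_neg - dist_pos))
target = torch.ones_like(dist_pos)
loss = margin_loss(dist_neg, dist_pos, target)
print(loss)  # mean hinge loss over the batch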
Example #4
def train_model(Hash_center, sym_net, model_prefix, dataset, input_conf, hash_bit,
                clip_length=16, train_frame_interval=2, val_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, 
                pretrained_3d=None, fine_tune=False,
                **kwargs):

    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
    iter_seed = torch.initial_seed()  \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(name=dataset,
                                                   batch_size=batch_size,
                                                   clip_length=clip_length,
                                                   train_interval=train_frame_interval,
                                                   val_interval=val_frame_interval,
                                                   mean=input_conf['mean'],
                                                   std=input_conf['std'],
                                                   seed=iter_seed)
    print(len(train_iter))
    print(len(eval_iter))
    # wrapper (dynamic model)
    net = model(net=sym_net,
                criterion=torch.nn.BCELoss().cuda(),
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size, # optional
                dataset=dataset,  # dataset name
                hash_bit=hash_bit,
                )
    net.net.cuda()

    # config optimization
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    for name, param in net.net.named_parameters():
        if fine_tune:
            #print(f'fine tune {fine_tune}')
            if name.startswith('hash'):
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        else:
            param_new_layers.append(param)

    if name_base_layers:
        out = "[\'" + '\', \''.join(name_base_layers) + "\']"
        logging.info("Optimizer:: >> recuding the learning rate of {} params: {}".format(len(name_base_layers),
                     out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    optimizer = torch.optim.SGD([{'params': param_base_layers, 'lr_mult': 0.2},
                                 {'params': param_new_layers, 'lr_mult': 1.0}],
                                lr=lr_base,
                                momentum=0.9,
                                weight_decay=0.0001,
                                nesterov=True)

    # load params from pretrained 3d network
    if pretrained_3d:
        if resume_epoch < 0:
            assert os.path.exists(pretrained_3d), "cannot locate: `{}'".format(pretrained_3d)
            logging.info("Initializer:: loading model states from: `{}'".format(pretrained_3d))
            checkpoint = torch.load(pretrained_3d)
            net.load_state(checkpoint['state_dict'], strict=False)
        else:
            logging.info("Initializer:: skip loading model states from: `{}'"
                + ", since it's going to be overwrited by the resumed model".format(pretrained_3d))

    # resume training: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * train_iter.__len__()

    # set learning rate scheduler
    num_worker = (dist.get_world_size()
                  if torch.distributed.is_initialized() else 1)
    lr_scheduler = MultiFactorScheduler(base_lr=lr_base,
                                        steps=[int(x/(batch_size*num_worker)) for x in lr_steps],
                                        factor=lr_factor,
                                        step_counter=step_counter)
    # define evaluation metric
    metrics = metric.MetricList(metric.Loss(name="loss-ce"))
                        
    # enable cudnn tune
    cudnn.benchmark = True

    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch,
            Hash_center=Hash_center,)
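
Note: Example #4 passes precomputed Hash_center codes to net.fit and uses torch.nn.BCELoss as the criterion; how the wrapper turns a clip's label into a binary target is not shown in this listing. A heavily hedged sketch of the usual center-based formulation (the sigmoid head and the shapes are assumptions):

import torch

bce = torch.nn.BCELoss()
hash_bit = 64

# assumed: the network ends in a sigmoid, so outputs lie in (0, 1)
outputs = torch.sigmoid(torch.randn(4, hash_bit))
# assumed: each sample's target is the {0, 1} hash center of its class
centers = (torch.randn(4, hash_bit) > 0).float()

loss = bce(outputs, centers)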
Example #5
    args.distributed = args.world_size > 1  # False

    net, input_conf = get_symbol(
        name=args.network,
        pretrained=args.pretrained_2d if args.resume_epoch < 0 else None,
        print_net=True if args.distributed else False,
        hash_bit=args.hash_bit,
        **dataset_cfg)
    net.eval()
    net = torch.nn.DataParallel(net).cuda()
    checkpoint = torch.load(args.pretrained_3d)
    net.load_state_dict(checkpoint['state_dict'])

    train_iter, eval_iter = iterator_factory.creat(
        name=dataset_name,
        batch_size=batch_size,
        clip_length=clip_length,
        train_interval=train_frame_interval,
        val_interval=val_frame_interval)
    print(len(train_iter))
    print(len(eval_iter))
    # print(net)
    print('Generating hash for database.............')
    database_hash, database_labels, dataset_path = predict_hash_code(
        net, train_iter)
    print(database_hash.shape)
    print(database_labels.shape)
    file_dir = 'dataset/' + args.dataset
    np.save(file_dir + '/database_hash.npy', database_hash)
    np.save(file_dir + '/database_label.npy', database_labels)
    np.save(file_dir + '/database_path.npy', dataset_path)
    print('Generating hash for test................')
Example #6
def train_model(sym_net,
                model_prefix,
                dataset,
                input_conf,
                clip_length=8,
                train_frame_interval=2,
                val_frame_interval=2,
                resume_epoch=-1,
                batch_size=4,
                save_frequency=1,
                lr_base=0.01,
                lr_factor=0.1,
                lr_steps=[400000, 800000],
                end_epoch=1000,
                distributed=False,
                fine_tune=False,
                epoch_div_factor=4,
                precise_bn=False,
                **kwargs):

    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(
        name=dataset,
        batch_size=batch_size,
        clip_length=clip_length,
        train_interval=train_frame_interval,
        val_interval=val_frame_interval,
        mean=input_conf['mean'],
        std=input_conf['std'],
        seed=iter_seed)
    # model (dynamic)
    net = model(
        net=sym_net,
        criterion=torch.nn.CrossEntropyLoss().cuda(),
        model_prefix=model_prefix,
        step_callback_freq=50,
        save_checkpoint_freq=save_frequency,
        opt_batch_size=batch_size,  # optional
        single_checkpoint=precise_bn,
        # TODO: use shared filesystem to rsync running mean/var
    )
    # if True:
    #     for name, module in net.net.named_modules():
    #         if name.endswith("bn"): module.momentum = 0.005
    net.net.cuda()

    # config optimization: params are bucketed as [idx_bn][idx_wd],
    # where idx_wd == 0 means no weight decay (biases) and idx_wd == 1 means weight decay
    param_base_layers = [[[], []], [[], []]]
    param_new_layers = [[[], []], [[], []]]
    name_freeze_layers, name_base_layers = [], []
    for name, param in net.net.named_parameters():
        idx_wd = 0 if name.endswith('.bias') else 1
        idx_bn = 0 if name.endswith(('.bias', 'bn.weight')) else 1
        if fine_tune:
            if not name.startswith('classifier'):
                param_base_layers[idx_bn][idx_wd].append(param)
                name_base_layers.append(name)
            else:
                param_new_layers[idx_bn][idx_wd].append(param)
        else:
            if "conv_m2" in name:
                param_base_layers[idx_bn][idx_wd].append(param)
                name_base_layers.append(name)
            else:
                param_new_layers[idx_bn][idx_wd].append(param)

    if name_freeze_layers:
        out = "[\'" + '\', \''.join(name_freeze_layers) + "\']"
        logging.info("Optimizer:: >> freezing {} params: {}".format(
            len(name_freeze_layers),
            out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))
    if name_base_layers:
        out = "[\'" + '\', \''.join(name_base_layers) + "\']"
        logging.info(
            "Optimizer:: >> recuding the learning rate of {} params: {}".
            format(
                len(name_base_layers),
                out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    # optimizer = torch.optim.SGD(sym_net.parameters(),
    wd = 0.0001
    optimizer = custom_optim.SGD(
        [
            {
                'params': param_base_layers[0][0],
                'lr_mult': 0.5,
                'weight_decay': 0.
            },
            {
                'params': param_base_layers[0][1],
                'lr_mult': 0.5,
                'weight_decay': wd
            },
            {
                'params': param_base_layers[1][0],
                'lr_mult': 0.5,
                'weight_decay': 0.,
                'name': 'precise.bn'
            },  # *.bias
            {
                'params': param_base_layers[1][1],
                'lr_mult': 0.5,
                'weight_decay': wd,
                'name': 'precise.bn'
            },  # bn.weight
            {
                'params': param_new_layers[0][0],
                'lr_mult': 1.0,
                'weight_decay': 0.
            },
            {
                'params': param_new_layers[0][1],
                'lr_mult': 1.0,
                'weight_decay': wd
            },
            {
                'params': param_new_layers[1][0],
                'lr_mult': 1.0,
                'weight_decay': 0.,
                'name': 'precise.bn'
            },  # *.bias
            {
                'params': param_new_layers[1][1],
                'lr_mult': 1.0,
                'weight_decay': wd,
                'name': 'precise.bn'
            }
        ],  # bn.weight
        lr=lr_base,
        momentum=0.9,
        nesterov=True)

    # resume: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * int(
            train_iter.__len__() / epoch_div_factor)

    num_worker = (torch.distributed.get_world_size()
                  if torch.distributed.is_initialized() else 1)
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)

    metrics = metric.MetricList(
        metric.Loss(name="loss-ce"),
        metric.Accuracy(name="top1", topk=1),
        metric.Accuracy(name="top5", topk=5),
    )

    cudnn.benchmark = True
    # cudnn.fastest = False
    # cudnn.enabled = False

    net.fit(
        train_iter=train_iter,
        eval_iter=eval_iter,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        metrics=metrics,
        epoch_start=epoch_start,
        epoch_end=end_epoch,
        epoch_div_factor=epoch_div_factor,
        precise_bn=precise_bn,
    )
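
Note: Example #6 exposes a precise_bn flag that is consumed inside net.fit, whose implementation is not shown here. "Precise BN" usually means re-estimating BatchNorm running statistics with extra forward passes over training data; a minimal sketch under that assumption (the loader and batch structure are illustrative):

import torch

@torch.no_grad()
def recompute_bn_stats(model, loader, num_batches=200):
    """Re-estimate BatchNorm running_mean/var with forward passes over training data."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, torch.nn.modules.batchnorm._BatchNorm)]
    for bn in bn_layers:
        bn.reset_running_stats()
        bn.momentum = None  # None => cumulative moving average over all seen batches
    model.train()  # BN layers only update running stats in train mode
    for i, (clips, _) in enumerate(loader):
        if i >= num_batches:
            break
        model(clips.cuda(non_blocking=True))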