def train_model(sym_net, model_prefix, dataset, input_conf,
                clip_length=16, train_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, fine_tune=False,
                **kwargs):
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
    iter_seed = torch.initial_seed() + 100 + max(0, resume_epoch) * 100
    train_iter = iter_fac.creat(name=dataset,
                                batch_size=batch_size,
                                clip_length=clip_length,
                                train_interval=train_frame_interval,
                                mean=input_conf['mean'],
                                std=input_conf['std'],
                                seed=iter_seed)

    # wrapper (dynamic model)
    net = model(net=sym_net,
                criterion=nn.CrossEntropyLoss().cuda(),
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size)
    net.net.cuda()

    # config optimization: when fine-tuning, train the classifier at full lr
    # and the backbone at a reduced lr (see 'lr_mult' below)
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    for name, param in net.net.named_parameters():
        if fine_tune:
            if ('classifier' in name) or ('fc' in name):
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        else:
            param_new_layers.append(param)

    if name_base_layers:
        out = "['" + "', '".join(name_base_layers) + "']"
        logging.info("Optimizer:: >> reducing the learning rate of {} params: {}".format(
            len(name_base_layers),
            out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    net.net = torch.nn.DataParallel(net.net).cuda()

    optimizer = torch.optim.SGD([{'params': param_base_layers, 'lr_mult': 0.2},
                                 {'params': param_new_layers, 'lr_mult': 1.0}],
                                lr=lr_base,
                                momentum=0.9,
                                weight_decay=0.0001,
                                nesterov=True)

    # resume training: model and optimizer
    if resume_epoch > 0:
        logging.info("Initializer:: resuming model from previous training")

    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * len(train_iter)

    # set learning rate scheduler
    num_worker = dist.get_world_size() if torch.distributed.is_initialized() else 1
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)

    # define evaluation metric
    metrics = metric.MetricList(metric.Loss(name="loss-ce"),
                                metric.Accuracy(name="top1", topk=1),
                                metric.Accuracy(name="top5", topk=5))

    net.fit(train_iter=train_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch)
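# ---------------------------------------------------------------------------
# Illustration (not part of this repo): plain torch.optim.SGD ignores the
# extra 'lr_mult' key stored in each param group above, so a per-step callback
# is expected to fold it into the group's learning rate. A minimal sketch,
# assuming the scheduler exposes an update() that returns the stepped base
# learning rate (the real MultiFactorScheduler API may differ):
def _apply_lr_mult(optimizer, lr_scheduler):
    base_lr = lr_scheduler.update()  # hypothetical accessor
    for group in optimizer.param_groups:
        group['lr'] = base_lr * group.get('lr_mult', 1.0)
# ---------------------------------------------------------------------------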
def train_model(net_name, sym_net, model_prefix, dataset, input_conf,
                modality='rgb', split=1, clip_length=16,
                train_frame_interval=2, val_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_base2=0.01, lr_d=None,
                lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, pretrained_3d=None,
                fine_tune=False, iter_size=1, optim='sgd', accumulate=True,
                ds_factor=16, epoch_thre=1, score_dir=None,
                mv_minmaxnorm=False, mv_loadimg=False, detach=False,
                adv=0, new_classifier=False, **kwargs):
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    torch.multiprocessing.set_sharing_strategy('file_system')
    import resource
    # raise the open-file limit; video loaders keep many files open
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))

    # data iterator
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(name=dataset,
                                                   batch_size=batch_size,
                                                   clip_length=clip_length,
                                                   train_interval=train_frame_interval,
                                                   val_interval=val_frame_interval,
                                                   mean=input_conf['mean'],
                                                   std=input_conf['std'],
                                                   seed=iter_seed,
                                                   modality=modality,
                                                   split=split,
                                                   net_name=net_name,
                                                   accumulate=accumulate,
                                                   ds_factor=ds_factor,
                                                   mv_minmaxnorm=mv_minmaxnorm,
                                                   mv_loadimg=mv_loadimg)

    # define an instance of class model
    net = model(net=sym_net,
                criterion=torch.nn.CrossEntropyLoss().cuda(),
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size,  # optional
                criterion2=torch.nn.MSELoss().cuda() if modality == 'flow+mp4' else None,
                criterion3=torch.nn.CrossEntropyLoss().cuda() if adv > 0. else None,
                adv=adv)
    net.net.cuda()
    print(torch.cuda.current_device(), torch.cuda.device_count())

    # config optimization: split parameters into flow-generator,
    # discriminator, new (classifier) and base (backbone) groups
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    params_gf = []
    params_d = []
    for name, param in net.net.named_parameters():
        if modality == 'flow+mp4':
            if name.startswith('gen_flow_model'):
                params_gf.append(param)
            elif name.startswith('discriminator'):
                params_d.append(param)
            elif name.startswith('conv3d_0c_1x1') or name.startswith('classifier'):
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        elif fine_tune:
            if name.startswith('classifier') or name.startswith('conv3d_0c_1x1'):
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        else:
            param_new_layers.append(param)

    if modality == 'flow+mp4':
        lr_mul = 0.2 if fine_tune else 0.5
    else:
        lr_mul = 0.2

    if name_base_layers:
        out = "['" + "', '".join(name_base_layers) + "']"
        logging.info("Optimizer:: >> reducing the learning rate of {} params: {} by factor {}".format(
            len(name_base_layers),
            out if len(out) < 300 else out[0:150] + " ... " + out[-150:],
            lr_mul))

    if net_name == 'I3D':
        weight_decay = 0.0001
    else:
        raise ValueError('UNKNOWN net_name', net_name)
    logging.info("Train_Model:: weight_decay: `{}'".format(weight_decay))

    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    if optim == 'adam':
        optimizer = torch.optim.Adam([{'params': param_base_layers, 'lr_mult': lr_mul},
                                      {'params': param_new_layers, 'lr_mult': 1.0}],
                                     lr=lr_base,
                                     weight_decay=weight_decay)
        optimizer_2 = torch.optim.Adam([{'params': param_base_layers, 'lr_mult': lr_mul},
                                        {'params': param_new_layers, 'lr_mult': 1.0}],
                                       lr=lr_base2,
                                       weight_decay=weight_decay)
    else:
        optimizer = torch.optim.SGD([{'params': param_base_layers, 'lr_mult': lr_mul},
                                     {'params': param_new_layers, 'lr_mult': 1.0}],
                                    lr=lr_base,
                                    momentum=0.9,
                                    weight_decay=weight_decay,
                                    nesterov=True)
        optimizer_2 = torch.optim.SGD([{'params': param_base_layers, 'lr_mult': lr_mul},
                                       {'params': param_new_layers, 'lr_mult': 1.0}],
                                      lr=lr_base2,
                                      momentum=0.9,
                                      weight_decay=weight_decay,
                                      nesterov=True)

    if adv > 0.:
        optimizer_3 = torch.optim.Adam(params_d, lr=lr_base,
                                       weight_decay=weight_decay, eps=0.001)
    else:
        optimizer_3 = None

    if modality == 'flow+mp4':
        if optim == 'adam':
            optimizer_mse = torch.optim.Adam(params_gf, lr=lr_base,
                                             weight_decay=weight_decay, eps=1e-08)
            optimizer_mse_2 = torch.optim.Adam(params_gf, lr=lr_base2,
                                               weight_decay=weight_decay, eps=0.001)
        else:
            optimizer_mse = torch.optim.SGD(params_gf, lr=lr_base, momentum=0.9,
                                            weight_decay=weight_decay, nesterov=True)
            optimizer_mse_2 = torch.optim.SGD(params_gf, lr=lr_base2, momentum=0.9,
                                              weight_decay=weight_decay, nesterov=True)
    else:
        optimizer_mse = None
        optimizer_mse_2 = None

    # load params from pretrained 3d network
    if pretrained_3d and pretrained_3d != 'False':
        if resume_epoch < 0:
            assert os.path.exists(pretrained_3d), "cannot locate: `{}'".format(pretrained_3d)
            logging.info("Initializer:: loading model states from: `{}'".format(pretrained_3d))
            if net_name == 'I3D':
                checkpoint = torch.load(pretrained_3d)
                # prefix keys with 'module.' to match the DataParallel wrapper
                state_dict = {'module.' + k: v for k, v in checkpoint.items()}
                del checkpoint
                net.load_state(state_dict, strict=False)
                if new_classifier:
                    checkpoint = torch.load('./network/pretrained/model_flow.pth')
                    state_dict = {'module.' + k: v for k, v in checkpoint.items()}
                    del checkpoint
                    net.load_state(state_dict, strict=False)
            else:
                checkpoint = torch.load(pretrained_3d)
                net.load_state(checkpoint['state_dict'], strict=False)
        else:
            logging.info("Initializer:: skip loading model states from: `{}',"
                         " since it's going to be overwritten by the resumed model".format(pretrained_3d))

    # resume training: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer,
                            optimizer_mse=optimizer_mse)
        epoch_start = resume_epoch
        step_counter = epoch_start * len(train_iter)

    # set learning rate scheduler
    num_worker = dist.get_world_size() if torch.distributed.is_initialized() else 1
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)
    if modality == 'flow+mp4':
        lr_scheduler2 = MultiFactorScheduler(
            base_lr=lr_base2,
            steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
            factor=lr_factor,
            step_counter=step_counter)
        if lr_d is None:
            # no separate discriminator lr given; fall back to the base lr
            lr_d = lr_base
        else:
            logging.info("Optimizer:: discriminator lr_d: {}".format(lr_d))
        lr_scheduler3 = MultiFactorScheduler(
            base_lr=lr_d,
            steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
            factor=lr_factor,
            step_counter=step_counter)
    else:
        lr_scheduler2 = None
        lr_scheduler3 = None

    # define evaluation metric
    metrics_D = None
    if modality == 'flow+mp4':
        metrics = metric.MetricList(metric.Loss(name="loss-ce"),
                                    metric.Loss(name="loss-mse"),
                                    metric.Accuracy(name="top1", topk=1),
                                    metric.Accuracy(name="top5", topk=5))
        if adv > 0:
            metrics_D = metric.MetricList(metric.Loss(name="classi_D"),
                                          metric.Loss(name="adv_D"))
    else:
        metrics = metric.MetricList(metric.Loss(name="loss-ce"),
                                    metric.Accuracy(name="top1", topk=1),
                                    metric.Accuracy(name="top5", topk=5))

    # enable cudnn tune
    cudnn.benchmark = True

    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch,
            iter_size=iter_size,
            optimizer_mse=optimizer_mse,
            optimizer_2=optimizer_2,
            optimizer_3=optimizer_3,
            optimizer_mse_2=optimizer_mse_2,
            lr_scheduler2=lr_scheduler2,
            lr_scheduler3=lr_scheduler3,
            metrics_D=metrics_D,
            epoch_thre=epoch_thre,
            score_dir=score_dir,
            detach=detach)
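# ---------------------------------------------------------------------------
# Illustration (not the repo's implementation): the multi-factor schedule used
# above keeps base_lr until the global step crosses each milestone in `steps`,
# then multiplies by `factor`. Note the callers convert sample-count milestones
# into iteration milestones via int(x / (batch_size * num_worker)). A minimal
# sketch under those assumptions:
class _MultiFactorSketch(object):
    def __init__(self, base_lr, steps, factor, step_counter=0):
        self.lr = base_lr
        self.steps = sorted(steps)
        self.factor = factor
        self.counter = step_counter
        # fast-forward past milestones already crossed when resuming
        while self.steps and self.counter >= self.steps[0]:
            self.steps.pop(0)
            self.lr *= self.factor

    def update(self):
        self.counter += 1
        if self.steps and self.counter >= self.steps[0]:
            self.steps.pop(0)
            self.lr *= self.factor
        return self.lr
# ---------------------------------------------------------------------------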
def train_model(sym_net, model_prefix, dataset, input_conf,
                clip_length=16, train_frame_interval=2, val_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, pretrained_3d=None,
                fine_tune=False, load_from_frames=False, use_flow=False,
                triplet_loss=False, **kwargs):
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(name=dataset,
                                                   batch_size=batch_size,
                                                   clip_length=clip_length,
                                                   train_interval=train_frame_interval,
                                                   val_interval=val_frame_interval,
                                                   mean=input_conf['mean'],
                                                   std=input_conf['std'],
                                                   seed=iter_seed,
                                                   load_from_frames=load_from_frames,
                                                   use_flow=use_flow)

    # wrapper (dynamic model)
    if use_flow:
        class LogNLLLoss(torch.nn.Module):
            """NLL loss on log-probabilities, for networks that already output softmax scores."""

            def __init__(self):
                super(LogNLLLoss, self).__init__()
                self.loss = torch.nn.NLLLoss()

            def forward(self, output, target):
                return self.loss(torch.log(output), target)

        # criterion = LogNLLLoss().cuda()
        criterion = torch.nn.CrossEntropyLoss().cuda()
    elif triplet_loss:
        logging.info("Using triplet loss")
        criterion = torch.nn.MarginRankingLoss().cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss().cuda()

    net = model(net=sym_net,
                criterion=criterion,
                triplet_loss=triplet_loss,
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size)  # optional
    net.net.cuda()

    # config optimization
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    for name, param in net.net.named_parameters():
        if fine_tune:
            if 'classifier' in name or 'fc' in name:
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        else:
            param_new_layers.append(param)

    if name_base_layers:
        out = "['" + "', '".join(name_base_layers) + "']"
        logging.info("Optimizer:: >> reducing the learning rate of {} params: {}".format(
            len(name_base_layers),
            out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    optimizer = torch.optim.SGD([{'params': param_base_layers, 'lr_mult': 0.2},
                                 {'params': param_new_layers, 'lr_mult': 1.0}],
                                lr=lr_base,
                                momentum=0.9,
                                weight_decay=0.0001,
                                nesterov=True)

    # load params from pretrained 3d network
    if pretrained_3d:
        if resume_epoch < 0:
            if os.path.exists(pretrained_3d):
                logging.info("Initializer:: loading model states from: `{}'".format(pretrained_3d))
                checkpoint = torch.load(pretrained_3d)
                net.load_state(checkpoint['state_dict'], mode='ada')
            else:
                logging.warning("cannot locate: `{}'".format(pretrained_3d))
        else:
            logging.info("Initializer:: skip loading model states from: `{}',"
                         " since it's going to be overwritten by the resumed model".format(pretrained_3d))

    # resume training: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * len(train_iter)

    # set learning rate scheduler
    num_worker = 1
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)

    # define evaluation metric
    if triplet_loss:
        metrics = metric.MetricList(metric.Loss(name="loss-triplet"),
                                    metric.TripletAccuracy(name="acc"))
    else:
        metrics = metric.MetricList(metric.Loss(name="loss-ce"),
                                    metric.Accuracy(name="top1", topk=1),
                                    metric.Accuracy(name="top5", topk=5))

    # enable cudnn tune
    cudnn.benchmark = True

    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch)
def train_model(Hash_center, sym_net, model_prefix, dataset, input_conf, hash_bit,
                clip_length=16, train_frame_interval=2, val_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, pretrained_3d=None,
                fine_tune=False, **kwargs):
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(name=dataset,
                                                   batch_size=batch_size,
                                                   clip_length=clip_length,
                                                   train_interval=train_frame_interval,
                                                   val_interval=val_frame_interval,
                                                   mean=input_conf['mean'],
                                                   std=input_conf['std'],
                                                   seed=iter_seed)
    print(len(train_iter))
    print(len(eval_iter))

    # wrapper (dynamic model)
    net = model(net=sym_net,
                criterion=torch.nn.BCELoss().cuda(),
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size,  # optional
                dataset=dataset,  # dataset name
                hash_bit=hash_bit)
    net.net.cuda()

    # config optimization: when fine-tuning, only the hash layer trains at full lr
    param_base_layers = []
    param_new_layers = []
    name_base_layers = []
    for name, param in net.net.named_parameters():
        if fine_tune:
            if name.startswith('hash'):
                param_new_layers.append(param)
            else:
                param_base_layers.append(param)
                name_base_layers.append(name)
        else:
            param_new_layers.append(param)

    if name_base_layers:
        out = "['" + "', '".join(name_base_layers) + "']"
        logging.info("Optimizer:: >> reducing the learning rate of {} params: {}".format(
            len(name_base_layers),
            out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    optimizer = torch.optim.SGD([{'params': param_base_layers, 'lr_mult': 0.2},
                                 {'params': param_new_layers, 'lr_mult': 1.0}],
                                lr=lr_base,
                                momentum=0.9,
                                weight_decay=0.0001,
                                nesterov=True)

    # load params from pretrained 3d network
    if pretrained_3d:
        if resume_epoch < 0:
            assert os.path.exists(pretrained_3d), "cannot locate: `{}'".format(pretrained_3d)
            logging.info("Initializer:: loading model states from: `{}'".format(pretrained_3d))
            checkpoint = torch.load(pretrained_3d)
            net.load_state(checkpoint['state_dict'], strict=False)
        else:
            logging.info("Initializer:: skip loading model states from: `{}',"
                         " since it's going to be overwritten by the resumed model".format(pretrained_3d))

    # resume training: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * len(train_iter)

    # set learning rate scheduler
    num_worker = dist.get_world_size() if torch.distributed.is_initialized() else 1
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)

    # define evaluation metric
    metrics = metric.MetricList(metric.Loss(name="loss-ce"))

    # enable cudnn tune
    cudnn.benchmark = True

    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch,
            Hash_center=Hash_center)
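# ---------------------------------------------------------------------------
# Illustration only: one common way a BCE criterion is paired with per-class
# hash centers (central-similarity style) is to squash the network's hash
# output into (0, 1) and treat the {-1, +1} center of the clip's class as a
# binary target. The tanh activation and the {-1,+1} -> {0,1} mapping are
# assumptions, not necessarily this repo's exact formulation:
def _central_similarity_loss(criterion, hash_logits, centers, labels):
    codes = 0.5 * (torch.tanh(hash_logits) + 1.0)  # (N, hash_bit) in (0, 1)
    targets = 0.5 * (centers[labels] + 1.0)        # {-1,+1} centers -> {0,1}
    return criterion(codes, targets)               # e.g. torch.nn.BCELoss()
# ---------------------------------------------------------------------------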
args.distributed = args.world_size > 1

net, input_conf = get_symbol(name=args.network,
                             pretrained=args.pretrained_2d if args.resume_epoch < 0 else None,
                             print_net=args.distributed,
                             hash_bit=args.hash_bit,
                             **dataset_cfg)
net.eval()
net = torch.nn.DataParallel(net).cuda()
checkpoint = torch.load(args.pretrained_3d)
net.load_state_dict(checkpoint['state_dict'])

train_iter, eval_iter = iterator_factory.creat(name=dataset_name,
                                               batch_size=batch_size,
                                               clip_length=clip_length,
                                               train_interval=train_frame_interval,
                                               val_interval=val_frame_interval)
print(len(train_iter))
print(len(eval_iter))

print('Generating hash for database.............')
database_hash, database_labels, dataset_path = predict_hash_code(net, train_iter)
print(database_hash.shape)
print(database_labels.shape)

file_dir = 'dataset/' + args.dataset
np.save(file_dir + '/database_hash.npy', database_hash)
np.save(file_dir + '/database_label.npy', database_labels)
np.save(file_dir + '/database_path.npy', dataset_path)

print('Generating hash for test................')
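# ---------------------------------------------------------------------------
# Illustration of how the saved .npy hash files can be consumed for retrieval
# (not part of this script): binarize the codes and rank database items by
# Hamming distance to each query. The sign convention is an assumption:
def _hamming_rank(query_hash, database_hash):
    q = np.sign(query_hash)      # (Nq, bits) in {-1, +1}
    db = np.sign(database_hash)  # (Nd, bits)
    bits = q.shape[1]
    dist = 0.5 * (bits - q @ db.T)   # Hamming distance via inner product
    return np.argsort(dist, axis=1)  # nearest database items first
# ---------------------------------------------------------------------------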
def train_model(sym_net, model_prefix, dataset, input_conf,
                clip_length=8, train_frame_interval=2, val_frame_interval=2,
                resume_epoch=-1, batch_size=4, save_frequency=1,
                lr_base=0.01, lr_factor=0.1, lr_steps=[400000, 800000],
                end_epoch=1000, distributed=False, fine_tune=False,
                epoch_div_factor=4, precise_bn=False, **kwargs):
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # data iterator
    iter_seed = torch.initial_seed() \
                + (torch.distributed.get_rank() * 10 if distributed else 100) \
                + max(0, resume_epoch) * 100
    train_iter, eval_iter = iterator_factory.creat(name=dataset,
                                                   batch_size=batch_size,
                                                   clip_length=clip_length,
                                                   train_interval=train_frame_interval,
                                                   val_interval=val_frame_interval,
                                                   mean=input_conf['mean'],
                                                   std=input_conf['std'],
                                                   seed=iter_seed)

    # model (dynamic)
    net = model(net=sym_net,
                criterion=torch.nn.CrossEntropyLoss().cuda(),
                model_prefix=model_prefix,
                step_callback_freq=50,
                save_checkpoint_freq=save_frequency,
                opt_batch_size=batch_size,  # optional
                single_checkpoint=precise_bn)  # TODO: use shared filesystem to rsync running mean/var
    net.net.cuda()

    # config optimization: params are bucketed as [idx_bn][idx_wd], where
    # idx_bn == 0 marks bias/BN-style params and idx_wd == 0 marks params
    # trained without weight decay
    param_base_layers = [[[], []], [[], []]]
    param_new_layers = [[[], []], [[], []]]
    name_base_layers = []
    for name, param in net.net.named_parameters():
        idx_wd = 0 if name.endswith('.bias') else 1
        idx_bn = 0 if name.endswith(('.bias', 'bn.weight')) else 1
        if fine_tune:
            if not name.startswith('classifier'):
                param_base_layers[idx_bn][idx_wd].append(param)
                name_base_layers.append(name)
            else:
                param_new_layers[idx_bn][idx_wd].append(param)
        else:
            if "conv_m2" in name:
                param_base_layers[idx_bn][idx_wd].append(param)
                name_base_layers.append(name)
            else:
                param_new_layers[idx_bn][idx_wd].append(param)

    if name_base_layers:
        out = "['" + "', '".join(name_base_layers) + "']"
        logging.info("Optimizer:: >> reducing the learning rate of {} params: {}".format(
            len(name_base_layers),
            out if len(out) < 300 else out[0:150] + " ... " + out[-150:]))

    if distributed:
        net.net = torch.nn.parallel.DistributedDataParallel(net.net).cuda()
    else:
        net.net = torch.nn.DataParallel(net.net).cuda()

    wd = 0.0001
    optimizer = custom_optim.SGD(
        [{'params': param_base_layers[0][0], 'lr_mult': 0.5, 'weight_decay': 0.},
         {'params': param_base_layers[0][1], 'lr_mult': 0.5, 'weight_decay': wd},
         {'params': param_base_layers[1][0], 'lr_mult': 0.5, 'weight_decay': 0.,
          'name': 'precise.bn'},
         {'params': param_base_layers[1][1], 'lr_mult': 0.5, 'weight_decay': wd,
          'name': 'precise.bn'},
         {'params': param_new_layers[0][0], 'lr_mult': 1.0, 'weight_decay': 0.},
         {'params': param_new_layers[0][1], 'lr_mult': 1.0, 'weight_decay': wd},
         {'params': param_new_layers[1][0], 'lr_mult': 1.0, 'weight_decay': 0.,
          'name': 'precise.bn'},
         {'params': param_new_layers[1][1], 'lr_mult': 1.0, 'weight_decay': wd,
          'name': 'precise.bn'}],
        lr=lr_base,
        momentum=0.9,
        nesterov=True)

    # resume: model and optimizer
    if resume_epoch < 0:
        epoch_start = 0
        step_counter = 0
    else:
        net.load_checkpoint(epoch=resume_epoch, optimizer=optimizer)
        epoch_start = resume_epoch
        step_counter = epoch_start * int(len(train_iter) / epoch_div_factor)

    num_worker = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
    lr_scheduler = MultiFactorScheduler(
        base_lr=lr_base,
        steps=[int(x / (batch_size * num_worker)) for x in lr_steps],
        factor=lr_factor,
        step_counter=step_counter)

    metrics = metric.MetricList(metric.Loss(name="loss-ce"),
                                metric.Accuracy(name="top1", topk=1),
                                metric.Accuracy(name="top5", topk=5))

    cudnn.benchmark = True

    net.fit(train_iter=train_iter,
            eval_iter=eval_iter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            epoch_start=epoch_start,
            epoch_end=end_epoch,
            epoch_div_factor=epoch_div_factor,
            precise_bn=precise_bn)
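# ---------------------------------------------------------------------------
# Illustration of the "precise BN" idea flagged above (hypothetical helper,
# not this repo's implementation): after training, re-estimate BatchNorm
# running statistics over a number of batches using a cumulative average
# instead of the EMA accumulated during optimization:
@torch.no_grad()
def _recompute_bn_stats(model, loader, num_batches=200):
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            m.momentum = None  # None => cumulative moving average in PyTorch
    model.train()
    for i, (data, _) in enumerate(loader):
        if i >= num_batches:
            break
        model(data.cuda())
# ---------------------------------------------------------------------------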