def build_model(self) -> nn.Module:
    """Build the Darknet model, load its starting weights and attach metadata.

    The checkpoint at ``opt.weights`` is filtered so that only tensors whose
    element count matches the corresponding model tensor are loaded
    (non-strict load). The number of classes, architecture name,
    hyperparameters and dataset-derived class weights are then attached to
    the model as attributes.

    Returns:
        The initialized model.

    Raises:
        KeyError: if the checkpoint has no ``"model"`` entry or contains a
            key that the model built from ``opt.cfg`` does not have.
    """
    opt = get_cli_args(
        batch_size=pedl_batch_size,
        prebias=pedl_prebias,
        accumulate=pedl_accumulate,
    )
    hyp = get_hyp(lr0=pedl_init_lr)

    # Initialize model
    model = Darknet(opt.cfg, arc=opt.arc)  # .to(device)

    # Fetch starting weights
    # TODO Once download_data_fn is implemented this should go into download_data
    attempt_download(opt.weights)
    # The model has not been moved to a device yet, so load onto the CPU;
    # this also lets a GPU-saved checkpoint be read on a CPU-only host.
    chkpt = torch.load(opt.weights, map_location="cpu")

    # load model
    try:
        # Build the state dict once instead of once per checkpoint key.
        model_state = model.state_dict()
        chkpt["model"] = {
            k: v
            for k, v in chkpt["model"].items()
            if model_state[k].numel() == v.numel()
        }
        model.load_state_dict(chkpt["model"], strict=False)
    except KeyError as e:
        s = (
            "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. "
            "See https://github.com/ultralytics/yolov3/issues/657"
            % (opt.weights, opt.cfg, opt.weights))
        raise KeyError(s) from e
    del chkpt

    data_dict = get_data_cfg()
    nc = 1 if opt.single_cls else int(data_dict["classes"])
    model.nc = nc  # attach number of classes to model
    model.arc = opt.arc  # attach yolo architecture
    model.hyp = hyp  # attach hyperparameters to model

    train_dataset = LazyModule.get()
    # The model class weights depend on the dataset labels
    model.class_weights = labels_to_class_weights(
        train_dataset.labels, nc)  # attach class weights
    return model
def greedy_channel_select(origin_model, prune_cfg, origin_weights,
                          select_layer, device, aux_util, data_loader,
                          pruned_rate):
    """Greedily select which input channels of one layer to keep.

    A masked copy of the network is built from ``prune_cfg`` and
    ``origin_weights``. ``k = in_channels - floor(in_channels * pruned_rate)``
    rounds are run; each round

    1. accumulates gradients of the joint loss
       (``hyp['joint_loss'] * MSE + detection loss + aux loss``) over
       ``data_loader`` without stepping the optimizer,
    2. scores every channel of the layer's ``MaskConv2d.weight`` from that
       accumulated gradient and marks the best not-yet-selected channel in
       ``selected_channels_mask``,
    3. runs one SGD pass over ``data_loader`` on the layer's ``MaskConv2d``
       parameters.

    The selected indices are only logged; nothing is returned.

    Args:
        origin_model: unpruned reference network; only its forward pass
            (under ``no_grad``) and forward hooks are used.
        prune_cfg: cfg passed to ``Darknet`` and ``mask_converted``.
        origin_weights: weights passed to ``mask_converted``.
        select_layer: name of the layer in ``module_list`` (a string that is
            also convertible with ``int``).
        device: ``torch.device`` to run on.
        aux_util: provides ``layer_info``, ``conv_layer_dict``,
            ``down_sample_layer`` and ``creat_aux_list``.
        data_loader: yields ``(imgs, targets, _, _)`` batches.
        pruned_rate: fraction of input channels to remove.
    """
    init_state_dict = mask_converted(prune_cfg, origin_weights, target=None)
    prune_model = Darknet(prune_cfg).to(device)
    prune_model.load_state_dict(init_state_dict, strict=True)
    del init_state_dict

    layer_idx = int(select_layer)
    # Keep a direct reference to the masked conv. The model is only wrapped
    # in DistributedDataParallel on multi-GPU hosts, so reaching the layer
    # through ``prune_model.module`` fails on CPU / single-GPU runs.
    mask_conv = prune_model.module_list[layer_idx].MaskConv2d
    solve_sub_problem_optimizer = optim.SGD(mask_conv.parameters(),
                                            lr=hyp['lr0'],
                                            momentum=hyp['momentum'])
    hook_util = HookUtils()
    handles = []

    info = aux_util.layer_info[layer_idx]
    in_channels = info['in_channels']
    remove_k = math.floor(in_channels * pruned_rate)
    k = in_channels - remove_k

    for name, child in origin_model.module_list.named_children():
        if name == select_layer:
            handles.append(
                child.BatchNorm2d.register_forward_hook(
                    hook_util.hook_origin_input))

    aux_idx = aux_util.conv_layer_dict[select_layer]
    hook_layer_aux = aux_util.down_sample_layer[aux_idx]
    for name, child in prune_model.module_list.named_children():
        if name == select_layer:
            handles.append(
                child.BatchNorm2d.register_forward_hook(
                    hook_util.hook_prune_input))
        elif name == hook_layer_aux:
            handles.append(
                child.register_forward_hook(hook_util.hook_prune_input))

    aux_net = aux_util.creat_aux_list(416,
                                      device,
                                      conv_layer_name=select_layer)
    chkpt_aux = torch.load(aux_weight, map_location=device)
    aux_net.load_state_dict(chkpt_aux['aux{}'.format(aux_idx)])
    del chkpt_aux

    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        prune_model = torch.nn.parallel.DistributedDataParallel(
            prune_model, find_unused_parameters=True)
        prune_model.yolo_layers = prune_model.module.yolo_layers
        aux_net = torch.nn.parallel.DistributedDataParallel(
            aux_net, find_unused_parameters=True)

    nb = len(data_loader)
    prune_model.nc = 80
    prune_model.hyp = hyp
    prune_model.arc = 'default'
    prune_model.eval()
    aux_net.eval()
    MSE = nn.MSELoss(reduction='mean')
    header = ('Stage', 'gpu_mem', 'iter', 'MSELoss', 'PdLoss', 'AuxLoss',
              'Total', 'targets')

    def joint_losses(imgs, targets):
        """Run both networks on one batch; return (mse, pruning, aux, total)."""
        with torch.no_grad():
            _ = origin_model(imgs)
        _, pruning_pred = prune_model(imgs)
        pruning_loss, _ = compute_loss(pruning_pred, targets, prune_model)

        hook_util.cat_to_gpu0('prune')
        aux_pred = aux_net(hook_util.prune_features['gpu0'][1])
        aux_loss, _ = AuxNetUtils.compute_loss_for_aux(
            aux_pred, aux_net, targets)

        # Kept as a 1-element tensor so it can be concatenated with the
        # other losses for the running mean below.
        mse_loss = torch.zeros(1).to(device)
        mse_loss += MSE(hook_util.prune_features['gpu0'][0],
                        hook_util.origin_features['gpu0'][0])
        loss = hyp['joint_loss'] * mse_loss + pruning_loss + aux_loss
        return mse_loss, pruning_loss, aux_loss, loss

    greedy = torch.zeros(k)
    for i_k in range(k):
        # ---- pass 1: accumulate gradients, no optimizer step ----
        pbar = tqdm(enumerate(data_loader), total=nb)
        print(('\n' + '%10s' * 8) % header)
        for i, (imgs, targets, _, _) in pbar:
            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float(
            ) / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            mse_loss, pruning_loss, aux_loss, loss = joint_losses(
                imgs, targets)
            loss.backward()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available(
            ) else 0
            s = ('%10s' * 3 + '%10.3g' * 5) % (
                'Pruning ' + select_layer, '%.3gG' % mem, '%g/%g' %
                (i_k, k), mse_loss, pruning_loss, aux_loss, loss, len(targets))
            pbar.set_description(s)

            hook_util.clean_hook_out('origin')
            hook_util.clean_hook_out('prune')

        # ---- channel scoring from the accumulated gradient ----
        grad = mask_conv.weight.grad.detach().clone()**2
        grad = grad.sum((2, 3)).sqrt().sum(0)

        if i_k == 0:
            # First round: reset the mask to a small value for every channel
            # and log what a one-shot top-k selection would have chosen.
            mask_conv.selected_channels_mask[:] = 1e-5
            _, non_greedy_indices = torch.topk(grad, k)
            logger.info('non greedy layer{}: selected==>{}'.format(
                select_layer, str(non_greedy_indices)))

        # Channels already selected (mask == 1) get a zero score.
        selected_channels_mask = mask_conv.selected_channels_mask
        _, indices = torch.topk(grad * (1 - selected_channels_mask), 1)
        mask_conv.selected_channels_mask[indices] = 1
        greedy[i_k] = indices
        logger.info('greedy layer{} iter{}: indices==>{}'.format(
            select_layer, str(i_k), str(indices)))

        prune_model.zero_grad()

        # ---- pass 2: optimize the layer with the current mask ----
        pbar = tqdm(enumerate(data_loader), total=nb)
        mloss = torch.zeros(4).to(device)
        print(('\n' + '%10s' * 8) % header)
        for i, (imgs, targets, _, _) in pbar:
            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float(
            ) / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            mse_loss, pruning_loss, aux_loss, loss = joint_losses(
                imgs, targets)
            loss.backward()

            solve_sub_problem_optimizer.step()
            solve_sub_problem_optimizer.zero_grad()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available(
            ) else 0
            mloss = (mloss * i + torch.cat(
                [mse_loss, pruning_loss, aux_loss, loss]).detach()) / (i + 1)
            s = ('%10s' * 3 + '%10.3g' * 5) % ('SubProm ' + select_layer,
                                               '%.3gG' % mem, '%g/%g' %
                                               (i_k, k), *mloss, len(targets))
            pbar.set_description(s)

            hook_util.clean_hook_out('origin')
            hook_util.clean_hook_out('prune')

    for handle in handles:
        handle.remove()

    logger.info(
        ("greedy layer{}: selected==>{}".format(select_layer, str(greedy))))
def fine_tune(prune_cfg, data, aux_util, device, train_loader, test_loader,
              epochs=10):
    """Fine-tune the pruned model together with its auxiliary model.

    Training state is read from and written back to the ``progress_chkpt``
    checkpoint after every epoch, so an interrupted run resumes at
    ``chkpt['epoch'] + 1``. Per-epoch results are appended to
    ``progress_result``.

    Args:
        prune_cfg: cfg passed to ``Darknet`` and ``test.test``.
        data: data config passed to ``test.test``.
        aux_util: provides ``conv_layer_dict`` and ``creat_aux_model``.
        device: ``torch.device`` to train on.
        train_loader: yields ``(img, targets, _, _)`` batches; its
            ``batch_size`` and ``dataset.img_size`` are read.
        test_loader: dataloader handed to ``test.test``.
        epochs: total number of fine-tuning epochs.

    Returns:
        ``chkpt['current_layer']``, the name of the layer to prune next.
    """
    with open(progress_result, 'a') as f:
        # 'mAP@0.5' had been mangled into an e-mail-obfuscation placeholder,
        # which also overflowed the 10-character column.
        f.write(('\n' + '%10s' * 10 + '\n') %
                ('Stage', 'Epoch', 'DIoU', 'obj', 'cls', 'Total', 'P', 'R',
                 'mAP@0.5', 'F1'))

    batch_size = train_loader.batch_size
    img_size = train_loader.dataset.img_size
    # Step the optimizer every `accumulate` batches (nominal batch of 64).
    # Clamp to 1 so a batch size above 64 does not yield a modulo by zero.
    accumulate = max(1, 64 // batch_size)
    hook_util = HookUtils()

    pruned_model = Darknet(prune_cfg,
                           img_size=(img_size, img_size)).to(device)

    chkpt = torch.load(progress_chkpt, map_location=device)
    pruned_model.load_state_dict(chkpt['model'], strict=True)

    current_layer = chkpt['current_layer']
    aux_in_layer = aux_util.conv_layer_dict[current_layer]
    aux_model = aux_util.creat_aux_model(aux_in_layer)
    aux_model.to(device)
    aux_model.load_state_dict(chkpt['aux_in{}'.format(aux_in_layer)],
                              strict=True)
    aux_loss_scalar = max(0.01, pow((int(aux_in_layer) + 1) / 75, 2))

    start_epoch = chkpt['epoch'] + 1
    if start_epoch == epochs:
        # Fine-tuning already finished; return the name of the layer to prune.
        return current_layer

    pg0, pg1 = [], []  # optimizer parameter groups
    for k, v in dict(pruned_model.named_parameters()).items():
        if 'MaskConv2d.weight' in k:
            pg1 += [v]  # parameter group 1 (apply weight_decay)
        else:
            pg0 += [v]  # parameter group 0

    for v in aux_model.parameters():
        pg0 += [v]  # parameter group 0

    optimizer = optim.SGD(pg0,
                          lr=hyp['lr0'],
                          momentum=hyp['momentum'],
                          nesterov=True)
    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    del pg0, pg1

    if chkpt['optimizer'] is not None:
        optimizer.load_state_dict(chkpt['optimizer'])

    del chkpt

    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[epochs // 3, 2 * (epochs // 3)],
        gamma=0.1)
    scheduler.last_epoch = start_epoch - 1

    # Keep the unwrapped network: the DistributedDataParallel wrapper (and
    # therefore `.module`) only exists on multi-GPU hosts.
    core_model = pruned_model
    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        pruned_model = nn.parallel.DistributedDataParallel(
            pruned_model, find_unused_parameters=True)
        pruned_model.yolo_layers = pruned_model.module.yolo_layers

    # -------------start train-------------
    nb = len(train_loader)
    pruned_model.nc = 80
    pruned_model.hyp = hyp
    pruned_model.arc = 'default'
    for epoch in range(start_epoch, epochs):

        # -------------register hook for model-------------
        handles = []
        for name, child in core_model.module_list.named_children():
            if name == aux_in_layer:
                handles.append(
                    child.register_forward_hook(hook_util.hook_prune_output))
        # -------------register hook for model-------------

        pruned_model.train()
        aux_model.train()

        print(('\n' + '%10s' * 7) %
              ('Stage', 'Epoch', 'gpu_mem', 'DIoU', 'obj', 'cls', 'total'))

        # -------------start batch-------------
        mloss = torch.zeros(4).to(device)
        pbar = tqdm(enumerate(train_loader), total=nb)
        for i, (img, targets, _, _) in pbar:
            if len(targets) == 0:
                continue

            ni = nb * epoch + i
            img = img.to(device).float() / 255.0
            targets = targets.to(device)

            pruned_pred = pruned_model(img)
            pruned_loss, pruned_loss_items = compute_loss(
                pruned_pred, targets, pruned_model)
            pruned_loss *= batch_size / 64

            hook_util.cat_to_gpu0()
            aux_pred = aux_model(hook_util.prune_features['gpu0'][0], targets)
            aux_loss = compute_loss_for_DCP(aux_pred, targets)
            aux_loss *= aux_loss_scalar * batch_size / 64

            loss = pruned_loss + aux_loss
            loss.backward()

            hook_util.clean_hook_out()
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()

            pruned_loss_items[2] += aux_loss.item()
            mloss = (mloss * i + pruned_loss_items) / (i + 1)
            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available(
            ) else 0
            s = ('%10s' * 3 + '%10.3g' * 4) % ('FiTune ' + current_layer,
                                               '%g/%g' % (epoch, epochs - 1),
                                               '%.3gG' % mem, *mloss)
            pbar.set_description(s)
        # -------------end batch-------------

        scheduler.step()
        for handle in handles:
            handle.remove()

        results, _ = test.test(prune_cfg,
                               data,
                               batch_size=batch_size * 2,
                               img_size=416,
                               model=pruned_model,
                               conf_thres=0.1,
                               iou_thres=0.5,
                               save_json=False,
                               dataloader=test_loader)

        # Checkpoint layout:
        # chkpt = {'current_layer':
        #          'epoch':
        #          'model':
        #          'optimizer':
        #          'aux_in12':
        #          'aux_in37':
        #          'aux_in62':
        #          'aux_in75':
        #          'prune_guide':}
        chkpt = torch.load(progress_chkpt, map_location=device)
        chkpt['current_layer'] = current_layer
        chkpt['epoch'] = epoch
        chkpt['model'] = core_model.state_dict()
        chkpt['optimizer'] = (None if epoch == epochs - 1 else
                              optimizer.state_dict())
        chkpt['aux_in{}'.format(aux_in_layer)] = aux_model.state_dict()

        torch.save(chkpt, progress_chkpt)
        torch.save(chkpt, last)

        if epoch == epochs - 1:
            torch.save(chkpt,
                       '../weights/DCP/backup{}.pt'.format(current_layer))

        del chkpt

        with open(progress_result, 'a') as f:
            f.write(('%10s' * 2 + '%10.3g' * 8) %
                    ('FiTune ' + current_layer, '%g/%g' %
                     (epoch, epochs - 1), *mloss, *results[:4]) + '\n')
    # -------------end train-------------

    torch.cuda.empty_cache()
    return current_layer
def channels_select(prune_cfg, data, origin_model, aux_util, device,
                    data_loader, select_layer, pruned_rate):
    """Greedily select the channels of ``select_layer`` to retain.

    The pruning model and the matching auxiliary model are restored from
    ``progress_chkpt``. ``floor(in_channels * (1 - pruned_rate))`` rounds are
    run; each round accumulates gradients of
    ``hyp['joint_loss'] * MSE + aux loss`` over ``n_iter`` batches, marks the
    highest-scoring not-yet-selected channel in the layer's
    ``selected_channels_mask`` (mirroring it into the layer named by
    ``aux_util.sync_guide`` when present), then runs ``n_iter`` optimizer
    batches. Afterwards unselected mask entries are set to 0, the model is
    evaluated with ``test.test`` and the checkpoint / result log are updated.

    Args:
        prune_cfg: cfg passed to ``Darknet`` and ``test.test``.
        data: data config passed to ``test.test``.
        origin_model: unpruned reference network (forward only, no grad).
        aux_util: provides ``conv_layer_dict``, ``creat_aux_model``,
            ``layer_info``, ``sync_guide`` and ``next_prune_layer``.
        device: ``torch.device`` to run on.
        data_loader: yields ``(imgs, targets, _, _)`` batches.
        select_layer: name of the layer to prune (string convertible with
            ``int``).
        pruned_rate: fraction of input channels to remove.
    """
    with open(progress_result, 'a') as f:
        # 'mAP@0.5' had been mangled into an e-mail-obfuscation placeholder,
        # which also overflowed the 10-character column.
        f.write(('\n' + '%10s' * 9 + '\n') %
                ('Stage', 'Change', 'MSELoss', 'AuxLoss', 'Total', 'P', 'R',
                 'mAP@0.5', 'F1'))

    logger.info(
        ('%10s' * 6) %
        ('Stage', 'Channels', 'Batch', 'MSELoss', 'AuxLoss', 'Total'))

    batch_size = data_loader.batch_size
    img_size = data_loader.dataset.img_size
    # Clamp to 1 so a batch size above 64 does not yield a modulo by zero.
    accumulate = max(1, 64 // batch_size)
    hook_util = HookUtils()
    handles = []
    n_iter = math.floor(500 / batch_size)

    pruning_model = Darknet(prune_cfg,
                            img_size=(img_size, img_size)).to(device)
    chkpt = torch.load(progress_chkpt, map_location=device)
    pruning_model.load_state_dict(chkpt['model'], strict=True)

    aux_in_layer = aux_util.conv_layer_dict[select_layer]
    aux_model = aux_util.creat_aux_model(aux_in_layer)
    aux_model.to(device)
    aux_model.load_state_dict(chkpt['aux_in{}'.format(aux_in_layer)],
                              strict=True)
    aux_loss_scalar = max(0.01, pow((int(aux_in_layer) + 1) / 75, 2))

    del chkpt

    # NOTE(review): the optimizer is built from the MaskConv2d of
    # `aux_in_layer`, not `select_layer` -- confirm this is intended.
    solve_sub_problem_optimizer = optim.SGD(
        pruning_model.module_list[int(aux_in_layer)].MaskConv2d.parameters(),
        lr=hyp['lr0'],
        momentum=hyp['momentum'])

    for name, child in origin_model.module_list.named_children():
        if name == aux_in_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_origin_output))
        if name == select_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_origin_output))

    for name, child in pruning_model.module_list.named_children():
        if name == aux_in_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_prune_output))
        if name == select_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_prune_output))

    # Grab the unwrapped network and the masked layers before any
    # DistributedDataParallel wrapping: `.module` only exists on the wrapper,
    # which is created on multi-GPU hosts only.
    core_model = pruning_model
    mask_conv = core_model.module_list[int(select_layer)].MaskConv2d
    layer_in_channels = aux_util.layer_info[select_layer]["in_channels"]
    sync_conv = None
    if select_layer in aux_util.sync_guide:
        sync_layer = aux_util.sync_guide[select_layer]
        sync_conv = core_model.module_list[int(sync_layer)].MaskConv2d

    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        origin_model = torch.nn.parallel.DistributedDataParallel(
            origin_model, find_unused_parameters=True)
        origin_model.yolo_layers = origin_model.module.yolo_layers
        pruning_model = torch.nn.parallel.DistributedDataParallel(
            pruning_model, find_unused_parameters=True)
        pruning_model.yolo_layers = pruning_model.module.yolo_layers

    retain_channels_num = math.floor(layer_in_channels * (1 - pruned_rate))

    pruning_model.nc = 80
    pruning_model.hyp = hyp
    pruning_model.arc = 'default'
    pruning_model.eval()
    aux_model.eval()

    MSE = nn.MSELoss(reduction='mean')
    mloss = torch.zeros(3).to(device)
    header = ('Stage', 'gpu_mem', 'channels', 'MSELoss', 'AuxLoss', 'Total')

    def batch_losses(imgs, targets):
        """Run both networks on one batch; return (mse, aux, pruning) losses."""
        with torch.no_grad():
            _ = origin_model(imgs)
        _, pruning_pred = pruning_model(imgs)
        pruning_loss, _ = compute_loss(pruning_pred, targets, pruning_model)

        hook_util.cat_to_gpu0()
        # Kept as a 1-element tensor so it can be concatenated below.
        mse_loss = torch.zeros(1, device=device)
        aux_pred = aux_model(hook_util.prune_features['gpu0'][1], targets)
        aux_loss = compute_loss_for_DCP(aux_pred, targets)
        mse_loss += MSE(hook_util.prune_features['gpu0'][0],
                        hook_util.origin_features['gpu0'][0])
        return mse_loss, aux_loss, pruning_loss

    for i_k in range(retain_channels_num):

        # One iterator feeds both passes of this round, so the loader must
        # supply at least 2 * n_iter batches.
        data_iter = iter(data_loader)

        # ---- pass 1: accumulate gradients, no optimizer step ----
        pbar = tqdm(range(n_iter), total=n_iter)
        print(('\n' + '%10s' * 6) % header)
        for i in pbar:
            # next() built-in instead of the Python-2 style `.next()` method.
            imgs, targets, _, _ = next(data_iter)
            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float(
            ) / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            mse_loss, aux_loss, pruning_loss = batch_losses(imgs, targets)
            # `0 * pruning_loss` keeps the detection head in the graph
            # without letting it contribute to the gradient.
            loss = hyp['joint_loss'] * mse_loss + aux_loss + 0 * pruning_loss
            loss.backward()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available(
            ) else 0
            s = ('%10s' * 3 + '%10.3g' * 3) % (
                'Prune ' + select_layer, '%.3gG' % mem, '%g/%g' %
                (i_k, retain_channels_num), hyp['joint_loss'] * mse_loss,
                aux_loss, loss)
            pbar.set_description(s)

            hook_util.clean_hook_out()

        # ---- channel scoring from the accumulated gradient ----
        grad = mask_conv.weight.grad.detach()**2
        grad = grad.sum((2, 3)).sqrt().sum(0)

        if i_k == 0:
            # First round: reset the masks to a small value.
            mask_conv.selected_channels_mask[:] = 1e-5
            if sync_conv is not None:
                sync_conv.selected_channels_mask[(
                    -1 * layer_in_channels):] = 1e-5

        # Channels already selected (mask == 1) get a zero score.
        selected_channels_mask = mask_conv.selected_channels_mask
        _, indices = torch.topk(grad * (1 - selected_channels_mask), 1)

        mask_conv.selected_channels_mask[indices] = 1
        if sync_conv is not None:
            # Same channel, addressed from the end of the synced layer's mask.
            sync_conv.selected_channels_mask[-(layer_in_channels -
                                               indices)] = 1

        pruning_model.zero_grad()

        # ---- pass 2: optimize with the current mask ----
        pbar = tqdm(range(n_iter), total=n_iter)
        print(('\n' + '%10s' * 6) % header)
        for i in pbar:
            imgs, targets, _, _ = next(data_iter)
            if len(targets) == 0:
                continue

            imgs = imgs.to(device).float(
            ) / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            mse_loss, aux_loss, pruning_loss = batch_losses(imgs, targets)
            loss = (hyp['joint_loss'] * mse_loss +
                    aux_loss_scalar * aux_loss + 0 * pruning_loss)
            loss.backward()

            if i % accumulate == 0:
                solve_sub_problem_optimizer.step()
                solve_sub_problem_optimizer.zero_grad()

            mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available(
            ) else 0
            mloss = (mloss * i + torch.cat(
                [hyp['joint_loss'] * mse_loss, aux_loss, loss]).detach()) / (
                    i + 1)
            s = ('%10s' * 3 + '%10.3g' * 3) % (
                'SubProm ' + select_layer, '%.3gG' % mem, '%g/%g' %
                (i_k, retain_channels_num), *mloss)
            pbar.set_description(s)

            if (i + 1) % n_iter == 0:
                logger.info(('%10s' * 3 + '%10.3g' * 3) %
                            ('SubPro' + select_layer, str(i_k), '%g/%g' %
                             (i, n_iter), *mloss))

            hook_util.clean_hook_out()

    for handle in handles:
        handle.remove()

    # Everything not explicitly selected (mask still below 1) is pruned.
    greedy_indices = mask_conv.selected_channels_mask < 1
    mask_conv.selected_channels_mask[greedy_indices] = 0

    res, _ = test.test(prune_cfg,
                       data,
                       batch_size=batch_size * 2,
                       img_size=416,
                       model=pruning_model,
                       conf_thres=0.1,
                       iou_thres=0.5,
                       save_json=False,
                       dataloader=None)

    chkpt = torch.load(progress_chkpt, map_location=device)
    chkpt['current_layer'] = aux_util.next_prune_layer(select_layer)
    chkpt['epoch'] = -1
    chkpt['model'] = core_model.state_dict()
    chkpt['optimizer'] = None

    torch.save(chkpt, progress_chkpt)
    torch.save(chkpt, last)
    del chkpt

    with open(progress_result, 'a') as f:
        f.write(('%10s' * 2 + '%10.3g' * 7) %
                ('Pruning ' + select_layer, str(layer_in_channels) + '->' +
                 str(retain_channels_num), *mloss, *res[:4]) + '\n')

    torch.cuda.empty_cache()
def train_aux_for_LCP(cfg, backbone, neck, data_loader, weights, aux_weight,
                      hyp, device, resume, epochs):
    """Train the auxiliary models used by the LCP strategy.

    The main network is loaded from ``weights`` and only run forward under
    ``no_grad``; forward hooks on the layers listed in
    ``aux_util.aux_in_layer`` feed one auxiliary model each. All auxiliary
    models share a single SGD optimizer. After every epoch the auxiliary
    state dicts are saved to ``aux_weight`` and
    ``../weights/LCP/aux-coco.pt`` and the last progress line is appended to
    ``./LCP/aux_result.txt``.

    Args:
        cfg: cfg passed to ``Darknet``.
        backbone: passed through to ``AuxNetUtils``.
        neck: passed through to ``AuxNetUtils``.
        data_loader: yields ``(imgs, targets, _, _)`` batches.
        weights: path of the main-network checkpoint (``'model'`` key).
        aux_weight: path of the auxiliary checkpoint to resume from / save to.
        hyp: hyperparameter dict (``'lr0'`` and ``'momentum'`` are read here).
        device: ``torch.device`` to train on.
        resume: when true, restore auxiliary models, optimizer and epoch from
            ``aux_weight``.
        epochs: total number of epochs.
    """
    init_seeds()

    batch_size = data_loader.batch_size
    # Clamp to 1 so a batch size above 64 does not yield a modulo by zero.
    accumulate = max(1, 64 // batch_size)

    model = Darknet(cfg).to(device)
    model_chkpt = torch.load(weights, map_location=device)
    model.load_state_dict(model_chkpt['model'], strict=True)
    del model_chkpt

    aux_util = AuxNetUtils(model, hyp, backbone, neck, strategy="LCP")
    hook_util = HookUtils()

    start_epoch = 0

    aux_model_list = []
    pg = []
    for layer in aux_util.aux_in_layer:
        aux_model = aux_util.creat_aux_model(layer)
        aux_model.to(device)
        pg.extend(aux_model.parameters())
        aux_model_list.append(aux_model)

    optimizer = optim.SGD(pg,
                          lr=hyp['lr0'],
                          momentum=hyp['momentum'],
                          nesterov=True)
    del pg

    if resume:
        chkpt = torch.load(aux_weight, map_location=device)

        for i, layer in enumerate(aux_util.aux_in_layer):
            aux_model_list[i].load_state_dict(chkpt['aux_in{}'.format(layer)],
                                              strict=True)

        if chkpt['optimizer'] is not None:
            optimizer.load_state_dict(chkpt['optimizer'])

        start_epoch = chkpt['epoch'] + 1

    scheduler = lr_scheduler.MultiStepLR(
        optimizer, milestones=[epochs // 3, 2 * epochs // 3], gamma=0.1)
    scheduler.last_epoch = start_epoch - 1

    handles = []  # hook handles must be released once training ends
    for name, child in model.module_list.named_children():
        if name in aux_util.aux_in_layer:
            handles.append(
                child.register_forward_hook(hook_util.hook_origin_output))

    if device.type != 'cpu' and torch.cuda.device_count() > 1:
        model = nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
        model.yolo_layers = model.module.yolo_layers

    nb = len(data_loader)
    model.nc = 80
    model.hyp = hyp
    model.arc = 'default'

    # Last progress line; initialised so the per-epoch result write below
    # cannot hit an unbound name when no batch contained targets.
    s = ''

    print('Starting training for %g epochs...' % epochs)
    for epoch in range(start_epoch, epochs):

        for aux_model in aux_model_list:
            aux_model.train()
        print(('\n' + '%10s' * 6) %
              ('Stage', 'Epoch', 'gpu_mem', 'AuxID', 'cls', 'targets'))

        # -----------------start batch-----------------
        pbar = tqdm(enumerate(data_loader), total=nb)
        model.train()
        for i, (imgs, targets, _, _) in pbar:

            if len(targets) == 0:
                continue

            ni = i + nb * epoch
            imgs = imgs.to(device).float(
            ) / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            targets = targets.to(device)

            # The main network is frozen: forward only, to fill the hooks.
            with torch.no_grad():
                prediction = model(imgs)

            hook_util.cat_to_gpu0()
            for aux_idx, aux_model in enumerate(aux_model_list):
                pred, loc_loss = aux_model(
                    hook_util.origin_features['gpu0'][aux_idx], targets,
                    prediction)
                loss = compute_loss_for_LCP(pred, loc_loss, targets)

                loss *= batch_size / 64
                loss.backward()

                mem = torch.cuda.memory_cached(
                ) / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s = ('%10s' * 3 + '%10.3g' * 3) % ('Train Aux', '%g/%g' %
                                                   (epoch, epochs - 1),
                                                   '%.3gG' % mem, aux_idx,
                                                   loss, len(targets))
                pbar.set_description(s)

            # hook outputs must be cleared after every batch
            hook_util.clean_hook_out()
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()
        # -----------------end batches-----------------

        scheduler.step()
        final_epoch = epoch + 1 == epochs
        chkpt = {
            'epoch': epoch,
            'optimizer': None if final_epoch else optimizer.state_dict()
        }
        for i, layer in enumerate(aux_util.aux_in_layer):
            chkpt['aux_in{}'.format(layer)] = aux_model_list[i].state_dict()

        torch.save(chkpt, aux_weight)
        torch.save(chkpt, "../weights/LCP/aux-coco.pt")

        del chkpt

        with open("./LCP/aux_result.txt", 'a') as f:
            f.write(s + '\n')

    # finally, remove every hook
    for handle in handles:
        handle.remove()
    torch.cuda.empty_cache()
# %%
# Merge the sample sets into a single collection
# (negative samples are currently left out).
samples = positive_samp  # + negative_samp

# %%
# Model configuration and loading
img_size = (352, 608)
arc = 'default'
cfg = 'cfg/yolov3-tiny-anchors.cfg'
weights = 'weights/best.pt'
device = torch_utils.select_device('cuda:0', apex=False, batch_size=64)

model = Darknet(cfg, img_size=img_size, arc=arc).to(device)
model.arc = 'default'
model.nc = 1  # num classes
model.hyp = hyp

# Restore the trained weights from the checkpoint's 'model' entry.
d = torch.load(weights, map_location=device)
m = d['model']
model.load_state_dict(m)

# Turn the samples into paths and wrap them in a FastAI ObjectItemList.
posix_paths = json_to_paths(samples)
lst = ObjectItemList(posix_paths, label_cls=YoloCategoryList)

# Expose the anchors of every YOLO layer, and the input size, on the
# label class.
YoloCategoryList.anchors = [
    model.module_list[layer_idx].anchors for layer_idx in model.yolo_layers
]
YoloCategoryList.img_size = img_size

# %%