from torch import optim


def create_optimizer(net, name, learning_rate, weight_decay, momentum=0,
                     fp16_loss_scale=None, optimizer_state=None):
    use_fp16 = fp16_loss_scale is not None
    if use_fp16:
        from apex import fp16_utils
        net = fp16_utils.network_to_half(net)
    # device = choose_device(device)
    # print('use', device)
    # net = net.to(device)

    # optimizer
    parameters = [p for p in net.parameters() if p.requires_grad]
    print('N of parameters', len(parameters))
    if name == 'sgd':
        optimizer = optim.SGD(parameters, lr=learning_rate,
                              momentum=momentum, weight_decay=weight_decay)
    elif name == 'adamw':
        from .adamw import AdamW
        optimizer = AdamW(parameters, lr=learning_rate, weight_decay=weight_decay)
    elif name == 'adam':
        optimizer = optim.Adam(parameters, lr=learning_rate, weight_decay=weight_decay)
    else:
        raise NotImplementedError(name)

    if use_fp16:
        from apex import fp16_utils
        if fp16_loss_scale == 0:
            opt_args = dict(dynamic_loss_scale=True)
        else:
            opt_args = dict(static_loss_scale=fp16_loss_scale)
        print('FP16_Optimizer', opt_args)
        optimizer = fp16_utils.FP16_Optimizer(optimizer, **opt_args)
    else:
        # give the plain optimizer the same backward() interface as FP16_Optimizer
        optimizer.backward = lambda loss: loss.backward()

    if optimizer_state:
        if use_fp16 and 'optimizer_state_dict' not in optimizer_state:
            # resume FP16_Optimizer.optimizer only
            optimizer.optimizer.load_state_dict(optimizer_state)
        elif use_fp16 and 'optimizer_state_dict' in optimizer_state:
            # resume optimizer from FP16_Optimizer.optimizer
            optimizer.load_state_dict(optimizer_state['optimizer_state_dict'])
        else:
            optimizer.load_state_dict(optimizer_state)

    return net, optimizer
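The two return values are meant to be used together: in fp16 mode `net` has been converted to half precision and `optimizer` is an apex FP16_Optimizer, while in fp32 mode the plain optimizer is patched so that `optimizer.backward(loss)` still works. A minimal usage sketch follows; `make_batches` and `compute_loss` are hypothetical placeholders, the hyperparameters are only illustrative, and fp16 assumes the model and inputs live on a CUDA device.

net, optimizer = create_optimizer(net, name='sgd', learning_rate=0.01,
                                  weight_decay=1e-4, momentum=0.9,
                                  fp16_loss_scale=0)  # 0 selects dynamic loss scaling
for inputs, targets in make_batches():          # hypothetical data source
    optimizer.zero_grad()
    loss = compute_loss(net(inputs), targets)   # hypothetical loss function
    # Works in both modes: FP16_Optimizer scales the loss internally,
    # while the fp32 path was patched to call plain loss.backward().
    optimizer.backward(loss)
    optimizer.step()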
def wrap_optimizer(optimizer):
    if _FP16_ENABLED:
        if _USE_FP16_OPTIMIZER:
            return fp16_utils.FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            return _amp_handle.wrap_optimizer(optimizer)
    else:
        return optimizer
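The FP16_Optimizer and amp-handle wrappers expose different backward interfaces, so call sites have to branch on which one wrap_optimizer returned. A hedged sketch, assuming `_amp_handle` comes from the old apex `amp.init()` API; `backward_step` is an illustrative helper, not part of this module.

def backward_step(optimizer, loss):
    # Illustrative helper, not part of this module.
    if isinstance(optimizer, fp16_utils.FP16_Optimizer):
        # FP16_Optimizer applies loss scaling inside its own backward().
        optimizer.backward(loss)
    elif hasattr(optimizer, 'scale_loss'):
        # Old apex amp handle wrapper: scale the loss explicitly.
        with optimizer.scale_loss(loss) as scaled_loss:
            scaled_loss.backward()
    else:
        # Plain fp32 optimizer, returned unchanged by wrap_optimizer().
        loss.backward()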
# parameters saved in checkpoint via model_path
if device == torch.device('cpu'):
    param = torch.load(pretrained_path, map_location='cpu')
else:
    param = torch.load(pretrained_path)
# param = torch.load(pretrained_path)
model.load_state_dict(param)
del param

# fp16
if fp16:
    from apex import fp16_utils
    model = fp16_utils.BN_convert_float(model.half())
    optimizer = fp16_utils.FP16_Optimizer(optimizer, verbose=False, dynamic_loss_scale=True)
    logger.info('Apply fp16')

# Restore model
if resume:
    model_path = output_dir.joinpath('model_tmp.pth')
    logger.info(f'Resume from {model_path}')
    param = torch.load(model_path)
    model.load_state_dict(param)
    del param
    opt_path = output_dir.joinpath('opt_tmp.pth')
    param = torch.load(opt_path)
    optimizer.load_state_dict(param)
    del param
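The resume branch expects model_tmp.pth and opt_tmp.pth under output_dir, so a matching save step presumably exists elsewhere in the training loop. A hypothetical sketch of that counterpart, mirroring the file names loaded above:

# Hypothetical checkpoint step matching the resume paths above.
# With fp16 enabled, optimizer is the FP16_Optimizer wrapper, and its
# state_dict() round-trips through load_state_dict() in the resume branch.
torch.save(model.state_dict(), output_dir.joinpath('model_tmp.pth'))
torch.save(optimizer.state_dict(), output_dir.joinpath('opt_tmp.pth'))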
def train(model, state, path, annotations, val_path, val_annotations, resize,
          max_size, jitter, batch_size, iterations, val_iterations,
          mixed_precision, lr, warmup, milestones, gamma, is_master=True,
          world=1, use_dali=True, verbose=True, metrics_url=None, logdir=None):
    'Train the model on the given dataset'

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path, jitter, max_size, batch_size, model.stride,
        world, annotations, training=True)
    if verbose:
        print(data_iterator)

    # Prepare model
    nn_model = model
    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()
    if mixed_precision:
        model = fp16_utils.BN_convert_float(model.half())
    if world > 1:
        model = DistributedDataParallel(model, delay_allreduce=True)
    model.train()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(), lr=lr, weight_decay=0.0001, momentum=0.9)
    if mixed_precision:
        optimizer = fp16_utils.FP16_Optimizer(optimizer, static_loss_scale=128., verbose=False)
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma ** len([m for m in milestones if m <= train_iter])
    scheduler = LambdaLR(optimizer.optimizer if mixed_precision else optimizer, schedule)

    if verbose:
        print(' device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'gpu' if world == 1 else 'gpus'))
        print(' batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            scheduler.step(iteration)

            # Forward pass
            profiler.start('fw')
            if mixed_precision:
                data = data.half()
            optimizer.zero_grad()
            cls_loss, box_loss = model([data, target])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if mixed_precision:
                optimizer.backward(cls_loss + box_loss)
            else:
                (cls_loss + box_loss).backward()
            optimizer.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean().clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))
            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60 or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration, iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'], batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size / profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate, iteration)
                del box_loss, focal_loss

                if metrics_url:
                    post_metrics(metrics_url, {
                        'focal loss': mean(cls_losses),
                        'box loss': mean(box_losses),
                        'im_s': batch_size / profiler.means['train'],
                        'lr': learning_rate})

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations or iteration % val_iterations == 0):
                infer(nn_model, val_path, None, resize, max_size, batch_size,
                      annotations=val_annotations, mixed_precision=mixed_precision,
                      is_master=is_master, world=world, use_dali=use_dali, verbose=False)
                model.train()

            if iteration == iterations:
                break

    if logdir is not None:
        writer.close()
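For reference, the schedule() multiplier above combines a linear warmup with step decay at the milestones. A worked example with hypothetical settings:

# Worked example of the schedule() multiplier, assuming (hypothetically)
# warmup=1000, milestones=[20000, 30000], gamma=0.1:
#   iteration   500 -> 0.9 * 500 / 1000 + 0.1 = 0.55   (linear warmup from 0.1)
#   iteration  5000 -> gamma ** 0 = 1.0                (base lr)
#   iteration 25000 -> gamma ** 1 = 0.1                (past the first milestone)
#   iteration 35000 -> gamma ** 2 = 0.01               (past both milestones)
# LambdaLR multiplies the optimizer's base lr by this factor at every scheduler.step().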