def create_optimizer(net,
                     name,
                     learning_rate,
                     weight_decay,
                     momentum=0,
                     fp16_loss_scale=None,
                     optimizer_state=None):
    """Build an optimizer for ``net`` and optionally resume its state.

    Args:
        net: module whose parameters with ``requires_grad`` set are optimized.
        name: optimizer to build: ``'sgd'``, ``'adamw'`` or ``'adam'``.
        learning_rate: forwarded to the optimizer as ``lr``.
        weight_decay: forwarded to the optimizer as ``weight_decay``.
        momentum: forwarded to the optimizer for ``'sgd'`` only.
        fp16_loss_scale: ``None`` keeps full precision. Any other value
            enables apex FP16 training: ``0`` selects dynamic loss scaling,
            every other value is used as the static loss scale.
        optimizer_state: optional state dict to resume from. With FP16
            enabled this may be a plain optimizer state dict (restored into
            the wrapped optimizer), a full ``FP16_Optimizer.state_dict()``,
            or a dict holding such a state under ``'optimizer_state_dict'``.

    Returns:
        Tuple ``(net, optimizer)``. With FP16 enabled ``net`` is the result
        of ``fp16_utils.network_to_half`` and ``optimizer`` is an
        ``FP16_Optimizer``; otherwise ``net`` is returned unchanged and the
        optimizer gets a ``backward(loss)`` attribute so callers can use the
        same call in both modes.

    Raises:
        NotImplementedError: if ``name`` is not a supported optimizer.
    """
    use_fp16 = fp16_loss_scale is not None
    if use_fp16:
        from apex import fp16_utils
        net = fp16_utils.network_to_half(net)

    # Only trainable parameters are handed to the optimizer.
    parameters = [p for p in net.parameters() if p.requires_grad]
    print('N of parameters', len(parameters))

    if name == 'sgd':
        optimizer = optim.SGD(parameters,
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
    elif name == 'adamw':
        from .adamw import AdamW
        optimizer = AdamW(parameters,
                          lr=learning_rate,
                          weight_decay=weight_decay)
    elif name == 'adam':
        optimizer = optim.Adam(parameters,
                               lr=learning_rate,
                               weight_decay=weight_decay)
    else:
        raise NotImplementedError(name)

    if use_fp16:
        if fp16_loss_scale == 0:
            opt_args = dict(dynamic_loss_scale=True)
        else:
            opt_args = dict(static_loss_scale=fp16_loss_scale)
        print('FP16_Optimizer', opt_args)
        optimizer = fp16_utils.FP16_Optimizer(optimizer, **opt_args)
    else:
        # Mirror FP16_Optimizer.backward so callers need no precision branch.
        optimizer.backward = lambda loss: loss.backward()

    if optimizer_state:
        if not use_fp16:
            optimizer.load_state_dict(optimizer_state)
        elif 'optimizer_state_dict' not in optimizer_state:
            # Plain optimizer state: restore only the wrapped optimizer.
            optimizer.optimizer.load_state_dict(optimizer_state)
        elif 'loss_scaler' in optimizer_state:
            # Full FP16_Optimizer.state_dict(): the wrapper's load_state_dict
            # needs the whole dict (loss scaler, fp32 master weights, ...),
            # not just the inner 'optimizer_state_dict' entry.
            optimizer.load_state_dict(optimizer_state)
        else:
            # Dict that merely stores the FP16_Optimizer state under
            # 'optimizer_state_dict' (previous behaviour, kept as is).
            optimizer.load_state_dict(optimizer_state['optimizer_state_dict'])

    return net, optimizer
# Example #2
# 0
def wrap_optimizer(optimizer):
    """Wrap ``optimizer`` for mixed precision according to the module flags.

    Returns ``optimizer`` untouched when ``_FP16_ENABLED`` is false.
    Otherwise returns an apex ``FP16_Optimizer`` with dynamic loss scaling
    when ``_USE_FP16_OPTIMIZER`` is set, or the result of
    ``_amp_handle.wrap_optimizer`` when it is not.
    """
    if not _FP16_ENABLED:
        return optimizer
    if _USE_FP16_OPTIMIZER:
        return fp16_utils.FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    return _amp_handle.wrap_optimizer(optimizer)
    # Load the weights stored at `pretrained_path` into `model`.
    # NOTE(review): the statement that encloses this indented fragment is not
    # visible here — confirm the surrounding condition against the full file.
    if device == torch.device('cpu'):
        param = torch.load(pretrained_path, map_location='cpu'
                           )  # remap the stored tensors onto the CPU
    else:
        param = torch.load(
            pretrained_path)  # keep the device placement stored in the file
    #param = torch.load(pretrained_path)
    model.load_state_dict(param)
    # Drop the reference to the loaded state dict once it has been applied.
    del param

# Mixed precision: switch the model to half precision (passing the result
# through apex's BN_convert_float) and wrap the optimizer in an
# FP16_Optimizer that uses dynamic loss scaling.
if fp16:
    from apex import fp16_utils
    model = fp16_utils.BN_convert_float(model.half())
    optimizer = fp16_utils.FP16_Optimizer(
        optimizer, dynamic_loss_scale=True, verbose=False)
    logger.info('Apply fp16')

# Resume: reload the temporary model and optimizer checkpoints that live in
# `output_dir`.
if resume:
    model_path = output_dir.joinpath('model_tmp.pth')
    logger.info(f'Resume from {model_path}')
    model.load_state_dict(torch.load(model_path))
    opt_path = output_dir.joinpath('opt_tmp.pth')
    optimizer.load_state_dict(torch.load(opt_path))
def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          is_master=True,
          world=1,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None):
    """Train the model on the given dataset.

    Args:
        model: network to train. It must expose ``stride`` and
            ``save(state)``, and return ``(cls_loss, box_loss)`` when called
            with ``[data, target]``.
        state: dict with the resumable training state. ``'optimizer'`` and
            ``'iteration'`` are read when present; ``'iteration'``,
            ``'optimizer'`` and ``'scheduler'`` are written before every
            checkpoint.
        path: training data location, forwarded to the data iterator.
        annotations: training annotations, forwarded to the data iterator.
        val_path: validation data location, forwarded to ``infer``.
        val_annotations: validation annotations; validation is skipped when
            this is falsy.
        resize: only forwarded to ``infer`` for validation.
        max_size: forwarded to the data iterator and to ``infer``.
        jitter: forwarded to the training data iterator.
        batch_size: forwarded to the data iterator and to ``infer``; also
            used for the throughput figures.
        iterations: iteration count at which training stops.
        val_iterations: validation runs whenever the iteration count is a
            multiple of this value (and at the final iteration).
        mixed_precision: train in half precision with an apex
            ``FP16_Optimizer`` using a static loss scale of 128.
        lr: base learning rate of the SGD optimizer.
        warmup: number of warm-up iterations during which the learning-rate
            multiplier grows linearly from 0.1 to 1.0; falsy disables it.
        milestones: iterations at which the multiplier is scaled by
            ``gamma``.
        gamma: decay factor applied once per milestone reached.
        is_master: whether this process logs, checkpoints and reports.
        world: number of participating processes; values above 1 enable
            ``DistributedDataParallel`` and loss all-reduce.
        use_dali: pick ``DaliDataIterator`` instead of ``DataIterator``.
        verbose: print progress information.
        metrics_url: when truthy, metrics are sent with ``post_metrics``.
        logdir: when not ``None``, the master process writes TensorBoard
            scalars to this directory.

    Raises:
        RuntimeError: on the master process when the summed loss is not
            finite.
    """

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path,
        jitter,
        max_size,
        batch_size,
        model.stride,
        world,
        annotations,
        training=True)
    if verbose: print(data_iterator)

    # Prepare model. `nn_model` keeps the caller's object: it is the one that
    # gets saved and validated, while `model` is the converted/wrapped copy
    # that is actually optimized.
    nn_model = model
    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()
    if mixed_precision:
        model = fp16_utils.BN_convert_float(model.half())
    if world > 1:
        model = DistributedDataParallel(model, delay_allreduce=True)
    model.train()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=0.0001,
                    momentum=0.9)
    if mixed_precision:
        optimizer = fp16_utils.FP16_Optimizer(optimizer,
                                              static_loss_scale=128.,
                                              verbose=False)
    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        """Return the learning-rate multiplier for iteration `train_iter`."""
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])

    # LambdaLR needs the real torch optimizer, not the FP16 wrapper.
    scheduler = LambdaLR(optimizer.optimizer if mixed_precision else optimizer,
                         schedule)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'gpu' if world == 1 else 'gpus'))
        print('    batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create the TensorBoard writer on the master process only: it is the
    # only process that ever logs scalars, so the other ranks would just
    # leave empty event files behind.
    writer = None
    if is_master and logdir is not None:
        from tensorboardX import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    try:
        profiler = Profiler(['train', 'fw', 'bw'])
        iteration = state.get('iteration', 0)
        while iteration < iterations:
            cls_losses, box_losses = [], []
            for data, target in data_iterator:
                # The explicit iteration makes the multiplier depend on the
                # global iteration count, including after a resume.
                scheduler.step(iteration)

                # Forward pass
                profiler.start('fw')
                if mixed_precision:
                    data = data.half()
                optimizer.zero_grad()
                cls_loss, box_loss = model([data, target])
                del data
                profiler.stop('fw')

                # Backward pass
                profiler.start('bw')
                if mixed_precision:
                    optimizer.backward(cls_loss + box_loss)
                else:
                    (cls_loss + box_loss).backward()
                optimizer.step()

                # Reduce all losses
                cls_loss = cls_loss.mean().clone()
                box_loss = box_loss.mean().clone()
                if world > 1:
                    torch.distributed.all_reduce(cls_loss)
                    torch.distributed.all_reduce(box_loss)
                    cls_loss /= world
                    box_loss /= world
                if is_master:
                    cls_losses.append(cls_loss)
                    box_losses.append(box_loss)

                if is_master and not isfinite(cls_loss + box_loss):
                    raise RuntimeError(
                        'Loss is diverging!\nTry lowering the learning rate.')

                del cls_loss, box_loss
                profiler.stop('bw')

                iteration += 1
                profiler.bump('train')
                # Report and checkpoint once the accumulated 'train' total
                # exceeds 60 (presumably seconds — confirm against Profiler)
                # and always at the final iteration.
                if is_master and (profiler.totals['train'] > 60
                                  or iteration == iterations):
                    focal_loss = torch.stack(cls_losses).mean().item()
                    box_loss = torch.stack(box_losses).mean().item()
                    learning_rate = optimizer.param_groups[0]['lr']
                    if verbose:
                        msg = '[{:{len}}/{}]'.format(iteration,
                                                     iterations,
                                                     len=len(str(iterations)))
                        msg += ' focal loss: {:.3f}'.format(focal_loss)
                        msg += ', box loss: {:.3f}'.format(box_loss)
                        msg += ', {:.3f}s/{}-batch'.format(
                            profiler.means['train'], batch_size)
                        msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                            profiler.means['fw'], profiler.means['bw'])
                        msg += ', {:.1f} im/s'.format(batch_size /
                                                      profiler.means['train'])
                        msg += ', lr: {:.2g}'.format(learning_rate)
                        print(msg, flush=True)

                    if writer is not None:
                        writer.add_scalar('focal_loss', focal_loss, iteration)
                        writer.add_scalar('box_loss', box_loss, iteration)
                        writer.add_scalar('learning_rate', learning_rate,
                                          iteration)
                        del box_loss, focal_loss

                    if metrics_url:
                        post_metrics(
                            metrics_url, {
                                'focal loss': mean(cls_losses),
                                'box loss': mean(box_losses),
                                'im_s': batch_size / profiler.means['train'],
                                'lr': learning_rate
                            })

                    # Save model weights
                    state.update({
                        'iteration': iteration,
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                    })
                    with ignore_sigint():
                        nn_model.save(state)

                    # Start a fresh reporting window.
                    profiler.reset()
                    del cls_losses[:], box_losses[:]

                if val_annotations and (iteration == iterations
                                        or iteration % val_iterations == 0):
                    infer(nn_model,
                          val_path,
                          None,
                          resize,
                          max_size,
                          batch_size,
                          annotations=val_annotations,
                          mixed_precision=mixed_precision,
                          is_master=is_master,
                          world=world,
                          use_dali=use_dali,
                          verbose=False)
                    # Back to training mode after validation.
                    model.train()

                if iteration == iterations:
                    break
    finally:
        # Close the writer even when training aborts (e.g. on a diverging
        # loss) so the scalars logged so far are not lost.
        if writer is not None:
            writer.close()