Code Example #1
def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)

    # build loss
    loss = build_loss(config)
    loss_fn = hooks.loss_fn
    hooks.loss_fn = lambda **kwargs: loss_fn(loss_fn=loss, **kwargs)

    # load checkpoint
    checkpoint = config.checkpoint
    last_epoch, step = dlcommon.utils.load_checkpoint(model, None, checkpoint)

    # build datasets
    dataloaders = build_dataloaders(config)

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # evaluation loop
    evaluate(config=config, model=model, dataloaders=dataloaders, hooks=hooks)
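
The reassignment of hooks.loss_fn above is a partial-application pattern: the original hook is captured, then rebound with the built loss module pre-injected, so later callers pass only the remaining keyword arguments. A minimal, self-contained sketch of the pattern (SimpleHooks and its loss_fn signature are hypothetical stand-ins, not part of dlcommon):

import torch

class SimpleHooks:
    def loss_fn(self, loss_fn, outputs, labels):
        # delegate to whichever loss module was injected
        return loss_fn(outputs, labels)

hooks = SimpleHooks()
loss = torch.nn.MSELoss()

# capture the original hook, then rebind it with `loss` pre-applied
loss_fn = hooks.loss_fn
hooks.loss_fn = lambda **kwargs: loss_fn(loss_fn=loss, **kwargs)

# callers now supply only the remaining arguments
value = hooks.loss_fn(outputs=torch.zeros(2, 3), labels=torch.ones(2, 3))
print(value)  # tensor(1.)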
Code Example #2
def run(config):
    # prepare directories
    prepare_directories(config)

    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)
    # build loss
    loss = build_loss(config)
    loss_fn = hooks.loss_fn
    hooks.loss_fn = lambda **kwargs: loss_fn(loss_fn=loss, **kwargs)
    
    # build optimizer
    params = model.parameters()
    optimizer = build_optimizer(config, params=params)

    model = model.cuda()
    # load checkpoint
    checkpoint = dlcommon.utils.get_initial_checkpoint(config)
    if checkpoint is not None:
        last_epoch, step = dlcommon.utils.load_checkpoint(model, optimizer, checkpoint)
        print('epoch, step:', last_epoch, step)
    else:
        last_epoch, step = -1, -1

    model, optimizer = to_data_parallel(config, model, optimizer)

    # build scheduler
    scheduler = build_scheduler(config,
                                optimizer=optimizer,
                                last_epoch=last_epoch)

    # build datasets
    dataloaders = build_dataloaders(config)

    # build summary writer
    writer = SummaryWriter(logdir=config.train.dir)
    logger_fn = hooks.logger_fn
    hooks.logger_fn = lambda **kwargs: logger_fn(writer=writer, **kwargs)

    # train loop
    train(config=config,
          model=model,
          optimizer=optimizer,
          scheduler=scheduler,
          dataloaders=dataloaders,
          hooks=hooks,
          last_epoch=last_epoch+1)
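
Unlike Code Example #1, the multi-GPU wrapping here is delegated to to_data_parallel, which returns both the (possibly wrapped) model and the optimizer. A hypothetical sketch consistent with the bare DataParallel wrapping seen in the other examples; the real dlcommon helper may differ:

import torch

def to_data_parallel(config, model, optimizer):
    # wrap only when more than one GPU is visible
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    # the optimizer already holds references to the underlying parameters,
    # so it can be returned unchanged
    return model, optimizer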
Code Example #3
def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # load checkpoint
    checkpoint = os.path.join(config.train.dir, config.checkpoint)
    last_epoch, step = dlcommon.utils.load_checkpoint(model, None, checkpoint)
    print(f'last_epoch:{last_epoch}')

    # build datasets
    dataloaders = build_dataloaders(config)

    # calculation method for anomaly score
    if config.loss.name == 'SSIMLoss':
        from dlcommon.losses import SSIMLoss
        score_fn = SSIMLoss(size_average=False)
    elif config.loss.name == 'MSELoss':
        from torch.nn import MSELoss

        class MSEInstances:
            def __init__(self):
                self.mse_elements = MSELoss(reduction='none')

            def __call__(self, input, target):
                # keep per-element losses, then average over the channel and
                # spatial dimensions to get one score per instance
                loss_elements = self.mse_elements(input, target)
                loss_instances = loss_elements.mean(dim=(1, 2, 3))
                return loss_instances

        score_fn = MSEInstances()
    else:
        raise ValueError(f'unsupported loss for anomaly score: {config.loss.name}')

    # inference loop
    inference(config=config,
              model=model,
              dataloaders=dataloaders,
              hooks=hooks,
              score_fn=score_fn)
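
MSEInstances above keeps the per-element losses (reduction='none') and averages over the channel and spatial dimensions, yielding one anomaly score per image in the batch. A self-contained usage sketch (repeating the class definition so it runs standalone):

import torch
from torch.nn import MSELoss

class MSEInstances:
    def __init__(self):
        self.mse_elements = MSELoss(reduction='none')

    def __call__(self, input, target):
        # per-element squared errors, averaged per instance
        return self.mse_elements(input, target).mean(dim=(1, 2, 3))

score_fn = MSEInstances()
recon = torch.zeros(4, 3, 8, 8)   # reconstructions
images = torch.ones(4, 3, 8, 8)   # originals
scores = score_fn(recon, images)
print(scores.shape)  # torch.Size([4]) -- one score per image
print(scores)        # tensor([1., 1., 1., 1.])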
Code Example #4
def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # load checkpoint
    checkpoint = config.checkpoint
    last_epoch, step = dlcommon.utils.load_checkpoint(model, None, checkpoint)
    print(f'last_epoch:{last_epoch}')

    # build datasets
    dataloaders = build_dataloaders(config)

    # inference loop
    inference(config=config,
              model=model,
              dataloaders=dataloaders,
              hooks=hooks)
Code Example #5
def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks, member='gan')
    G = model.G
    D = model.D
    E = build_model(config, hooks, member='encoder')

    def _freeze_model(_model):
        # inference only: disable gradients for every parameter
        for param in _model.parameters():
            param.requires_grad = False

    _freeze_model(G)
    _freeze_model(D)
    _freeze_model(E)

    G = G.cuda()
    D = D.cuda()
    E = E.cuda()
    if torch.cuda.device_count() > 1:
        G = torch.nn.DataParallel(G)
        D = torch.nn.DataParallel(D)
        E = torch.nn.DataParallel(E)

    # load checkpoint
    def load_from_checkpoint(_model, checkpoint_name):
        checkpoint = os.path.join(config.train.dir, checkpoint_name)
        # the returned (last_epoch, step) are not needed for inference
        dlcommon.utils.load_checkpoint(_model, None, checkpoint)

    load_from_checkpoint(G, config.checkpoint.g)
    load_from_checkpoint(D, config.checkpoint.d)
    load_from_checkpoint(E, config.checkpoint.e)

    # build datasets
    dataloaders = build_dataloaders(config)

    # calculation method for anomaly score
    from torch.nn import MSELoss

    class MSEInstances:
        def __init__(self):
            self.mse_elements = MSELoss(reduction='none')

        def __call__(self, input, target):
            # flatten all non-batch dimensions, then average to get one
            # score per instance regardless of input rank
            loss_elements = self.mse_elements(input, target)
            loss_elements = torch.flatten(loss_elements, start_dim=1)
            loss_instances = loss_elements.mean(dim=1)
            return loss_instances

    score_fn = MSEInstances()

    # inference loop
    inference(config=config,
              G=G,
              D=D,
              E=E,
              dataloaders=dataloaders,
              hooks=hooks,
              score_fn=score_fn)
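
The MSEInstances variant here flattens all non-batch dimensions before averaging, which is equivalent to the dim=(1, 2, 3) mean in Code Example #3 for 4-D inputs but also handles other input ranks. A quick check of that equivalence:

import torch

x = torch.randn(4, 3, 8, 8)
per_dim = x.mean(dim=(1, 2, 3))
flat = torch.flatten(x, start_dim=1).mean(dim=1)
print(torch.allclose(per_dim, flat))  # True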
Code Example #6
def run(config):
    # prepare directories
    prepare_directories(config)

    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)
    # build loss
    loss = build_loss(config)
    loss_fn = hooks.loss_fn
    hooks.loss_fn = lambda **kwargs: loss_fn(loss_fn=loss, **kwargs)

    # build optimizer
    if 'no_bias_decay' in config.train and config.train.no_bias_decay:
        if 'encoder_lr_ratio' in config.train:
            encoder_lr_ratio = config.train.encoder_lr_ratio
            group_decay_encoder, group_no_decay_encoder = group_weight(
                model.encoder)
            base_lr = config.optimizer.params.lr
            params = [{
                'params': model.product.parameters(),
                'lr': base_lr
            }, {
                'params': model.fc.parameters(),
                'lr': base_lr
            }, {
                'params': group_decay_encoder,
                'lr': base_lr * encoder_lr_ratio
            }, {
                'params': group_no_decay_encoder,
                'lr': base_lr * encoder_lr_ratio,
                'weight_decay': 0.0
            }]
        else:
            group_decay, group_no_decay = group_weight(model)
            params = [{
                'params': group_decay
            }, {
                'params': group_no_decay,
                'weight_decay': 0.0
            }]
    elif 'encoder_lr_ratio' in config.train:
        denom = config.train.encoder_lr_ratio
        base_lr = config.optimizer.params.lr
        params = [{
            'params': model.encoder.parameters(),
            'lr': base_lr * denom
        }, {
            'params': model.fc.parameters(),
            'lr': base_lr
        }, {
            'params': model.product.parameters(),
            'lr': base_lr
        }]
    else:
        params = model.parameters()
    optimizer = build_optimizer(config, params=params)

    model = model.cuda()
    # load checkpoint
    checkpoint = dlcommon.utils.get_initial_checkpoint(config)
    if checkpoint is not None:
        last_epoch, step = dlcommon.utils.load_checkpoint(
            model, optimizer, checkpoint)
        print('epoch, step:', last_epoch, step)
    else:
        last_epoch, step = -1, -1

    model, optimizer = to_data_parallel(config, model, optimizer)

    # build scheduler
    scheduler = build_scheduler(config,
                                optimizer=optimizer,
                                last_epoch=last_epoch)

    # build datasets
    dataloaders = build_dataloaders(config)

    # build summary writer
    writer = SummaryWriter(logdir=config.train.dir)
    logger_fn = hooks.logger_fn
    hooks.logger_fn = lambda **kwargs: logger_fn(writer=writer, **kwargs)

    # train loop
    train(config=config,
          model=model,
          optimizer=optimizer,
          scheduler=scheduler,
          dataloaders=dataloaders,
          hooks=hooks,
          last_epoch=last_epoch + 1)
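
The optimizer setup above relies on group_weight to split a module's parameters into a weight-decay group and a no-decay group; the usual candidates for exemption are biases and normalization parameters. A hypothetical sketch of such a splitter, which the real dlcommon implementation may differ from:

import torch

def group_weight(module):
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors cover biases and norm-layer scales/offsets
        if param.ndim <= 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    return decay, no_decay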