def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)

    # build loss
    loss = build_loss(config)
    loss_fn = hooks.loss_fn
    hooks.loss_fn = lambda **kwargs: loss_fn(loss_fn=loss, **kwargs)

    # load checkpoint
    checkpoint = config.checkpoint
    last_epoch, step = dlcommon.utils.load_checkpoint(model, None, checkpoint)

    # build datasets
    dataloaders = build_dataloaders(config)

    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # evaluation loop
    evaluate(config=config,
             model=model,
             dataloaders=dataloaders,
             hooks=hooks)
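# The loss-hook wrapping above is plain partial application: the original
# hooks.loss_fn is captured once, then re-exposed with the built loss bound to
# its `loss_fn` keyword, so callers keep the same interface. A minimal
# self-contained sketch of the pattern (the _Hooks class and _demo_loss names
# here are hypothetical, for illustration only):

class _Hooks:
    def loss_fn(self, loss_fn=None, outputs=None, targets=None):
        return loss_fn(outputs, targets)

def _demo_loss(outputs, targets):
    return sum((o - t) ** 2 for o, t in zip(outputs, targets))

_hooks = _Hooks()
_original = _hooks.loss_fn
_hooks.loss_fn = lambda **kwargs: _original(loss_fn=_demo_loss, **kwargs)

# callers never see the bound loss:
print(_hooks.loss_fn(outputs=[1.0, 2.0], targets=[0.0, 2.0]))  # 1.0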
def run(config):
    # prepare directories
    prepare_directories(config)

    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)

    # build loss
    loss = build_loss(config)
    loss_fn = hooks.loss_fn
    hooks.loss_fn = lambda **kwargs: loss_fn(loss_fn=loss, **kwargs)

    # build optimizer
    params = model.parameters()
    optimizer = build_optimizer(config, params=params)

    model = model.cuda()

    # load checkpoint
    checkpoint = dlcommon.utils.get_initial_checkpoint(config)
    if checkpoint is not None:
        last_epoch, step = dlcommon.utils.load_checkpoint(model, optimizer, checkpoint)
        print('epoch, step:', last_epoch, step)
    else:
        last_epoch, step = -1, -1

    model, optimizer = to_data_parallel(config, model, optimizer)

    # build scheduler
    scheduler = build_scheduler(config, optimizer=optimizer, last_epoch=last_epoch)

    # build datasets
    dataloaders = build_dataloaders(config)

    # build summary writer
    writer = SummaryWriter(logdir=config.train.dir)
    logger_fn = hooks.logger_fn
    hooks.logger_fn = lambda **kwargs: logger_fn(writer=writer, **kwargs)

    # train loop
    train(config=config,
          model=model,
          optimizer=optimizer,
          scheduler=scheduler,
          dataloaders=dataloaders,
          hooks=hooks,
          last_epoch=last_epoch + 1)
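# `to_data_parallel` is provided elsewhere in the project. A minimal sketch of
# what it is assumed to do: wrap the model in torch.nn.DataParallel when more
# than one GPU is visible and pass the optimizer through unchanged. This is an
# assumption for illustration, not the project's actual helper.

import torch

def to_data_parallel_sketch(config, model, optimizer):
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)
    return model, optimizer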
def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # load checkpoint
    checkpoint = os.path.join(config.train.dir, config.checkpoint)
    last_epoch, step = dlcommon.utils.load_checkpoint(model, None, checkpoint)
    print(f'last_epoch:{last_epoch}')

    # build datasets
    dataloaders = build_dataloaders(config)

    # calculation method for anomaly score
    if config.loss.name == 'SSIMLoss':
        from dlcommon.losses import SSIMLoss
        score_fn = SSIMLoss(size_average=False)
    elif config.loss.name == 'MSELoss':
        from torch.nn import MSELoss

        class MSEInstances:
            def __init__(self):
                self.mse_elements = MSELoss(reduction='none')

            def __call__(self, input, target):
                # per-element squared errors, averaged over C, H, W
                # to give one score per instance in the batch
                loss_elements = self.mse_elements(input, target)
                loss_instances = loss_elements.mean(dim=(1, 2, 3))
                return loss_instances

        score_fn = MSEInstances()
    else:
        raise ValueError(f'unsupported loss for anomaly score: {config.loss.name}')

    # inference loop
    inference(config=config,
              model=model,
              dataloaders=dataloaders,
              hooks=hooks,
              score_fn=score_fn)
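# Quick check of the per-instance scoring above: for a batch of N images,
# MSEInstances returns one scalar anomaly score per image (shape [N]).
# A minimal sketch with random tensors (shapes chosen arbitrarily):

import torch
from torch.nn import MSELoss

class _MSEInstancesDemo:
    def __init__(self):
        self.mse_elements = MSELoss(reduction='none')

    def __call__(self, input, target):
        return self.mse_elements(input, target).mean(dim=(1, 2, 3))

_scores = _MSEInstancesDemo()(torch.rand(4, 3, 32, 32), torch.rand(4, 3, 32, 32))
print(_scores.shape)  # torch.Size([4])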
def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # load checkpoint
    checkpoint = config.checkpoint
    last_epoch, step = dlcommon.utils.load_checkpoint(model, None, checkpoint)
    print(f'last_epoch:{last_epoch}')

    # build datasets
    dataloaders = build_dataloaders(config)

    # inference loop
    inference(config=config,
              model=model,
              dataloaders=dataloaders,
              hooks=hooks)
def run(config):
    # build hooks
    hooks = build_hooks(config)

    # build models: pretrained GAN (G, D) plus a separately trained encoder E
    model = build_model(config, hooks, member='gan')
    G = model.G
    D = model.D
    E = build_model(config, hooks, member='encoder')

    # freeze all parameters; this runner only does inference
    def _freeze_model(_model):
        for param in _model.parameters():
            param.requires_grad = False

    _freeze_model(G)
    _freeze_model(D)
    _freeze_model(E)

    G = G.cuda()
    D = D.cuda()
    E = E.cuda()
    if torch.cuda.device_count() > 1:
        G = torch.nn.DataParallel(G)
        D = torch.nn.DataParallel(D)
        E = torch.nn.DataParallel(E)

    # load checkpoints
    def load_from_checkpoint(_model, checkpoint_name):
        checkpoint = os.path.join(config.train.dir, checkpoint_name)
        dlcommon.utils.load_checkpoint(_model, None, checkpoint)

    load_from_checkpoint(G, config.checkpoint.g)
    load_from_checkpoint(D, config.checkpoint.d)
    load_from_checkpoint(E, config.checkpoint.e)

    # build datasets
    dataloaders = build_dataloaders(config)

    # calculation method for anomaly score
    from torch.nn import MSELoss

    class MSEInstances:
        def __init__(self):
            self.mse_elements = MSELoss(reduction='none')

        def __call__(self, input, target):
            # flatten per-element errors and average to one score per instance
            loss_elements = self.mse_elements(input, target)
            loss_elements = torch.flatten(loss_elements, start_dim=1)
            loss_instances = loss_elements.mean(dim=1)
            return loss_instances

    score_fn = MSEInstances()

    # inference loop
    inference(config=config,
              G=G,
              D=D,
              E=E,
              dataloaders=dataloaders,
              hooks=hooks,
              score_fn=score_fn)
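# How score_fn is assumed to be used inside `inference`: the encoder maps an
# input batch to latent codes, the frozen generator reconstructs the batch,
# and the per-instance MSE between input and reconstruction serves as the
# anomaly score (an f-AnoGAN-style scheme). This sketch is an assumption for
# illustration, not the project's actual inference code:

import torch

@torch.no_grad()
def anomaly_scores_sketch(G, E, score_fn, images):
    latent = E(images)               # encode inputs into the GAN's latent space
    recon = G(latent)                # reconstruct through the frozen generator
    return score_fn(recon, images)   # one score per image; high = anomalous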
def run(config):
    # prepare directories
    prepare_directories(config)

    # build hooks
    hooks = build_hooks(config)

    # build model
    model = build_model(config, hooks)

    # build loss
    loss = build_loss(config)
    loss_fn = hooks.loss_fn
    hooks.loss_fn = lambda **kwargs: loss_fn(loss_fn=loss, **kwargs)

    # build optimizer, optionally with no weight decay on biases/norm params
    # and a scaled learning rate for the encoder
    if 'no_bias_decay' in config.train and config.train.no_bias_decay:
        if 'encoder_lr_ratio' in config.train:
            encoder_lr_ratio = config.train.encoder_lr_ratio
            group_decay_encoder, group_no_decay_encoder = group_weight(model.encoder)
            base_lr = config.optimizer.params.lr
            params = [
                {'params': model.product.parameters(), 'lr': base_lr},
                {'params': model.fc.parameters(), 'lr': base_lr},
                {'params': group_decay_encoder, 'lr': base_lr * encoder_lr_ratio},
                {'params': group_no_decay_encoder, 'lr': base_lr * encoder_lr_ratio,
                 'weight_decay': 0.0},
            ]
        else:
            group_decay, group_no_decay = group_weight(model)
            params = [
                {'params': group_decay},
                {'params': group_no_decay, 'weight_decay': 0.0},
            ]
    elif 'encoder_lr_ratio' in config.train:
        encoder_lr_ratio = config.train.encoder_lr_ratio
        base_lr = config.optimizer.params.lr
        params = [
            {'params': model.encoder.parameters(), 'lr': base_lr * encoder_lr_ratio},
            {'params': model.fc.parameters(), 'lr': base_lr},
            {'params': model.product.parameters(), 'lr': base_lr},
        ]
    else:
        params = model.parameters()

    optimizer = build_optimizer(config, params=params)

    model = model.cuda()

    # load checkpoint
    checkpoint = dlcommon.utils.get_initial_checkpoint(config)
    if checkpoint is not None:
        last_epoch, step = dlcommon.utils.load_checkpoint(model, optimizer, checkpoint)
        print('epoch, step:', last_epoch, step)
    else:
        last_epoch, step = -1, -1

    model, optimizer = to_data_parallel(config, model, optimizer)

    # build scheduler
    scheduler = build_scheduler(config, optimizer=optimizer, last_epoch=last_epoch)

    # build datasets
    dataloaders = build_dataloaders(config)

    # build summary writer
    writer = SummaryWriter(logdir=config.train.dir)
    logger_fn = hooks.logger_fn
    hooks.logger_fn = lambda **kwargs: logger_fn(writer=writer, **kwargs)

    # train loop
    train(config=config,
          model=model,
          optimizer=optimizer,
          scheduler=scheduler,
          dataloaders=dataloaders,
          hooks=hooks,
          last_epoch=last_epoch + 1)
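# `group_weight` is defined elsewhere in the project. A minimal sketch of the
# usual no-bias-decay split it is assumed to perform: conv/linear weights get
# weight decay, while biases and normalization parameters do not. This is an
# assumption for illustration, not the project's actual helper.

import torch.nn as nn

def group_weight_sketch(module):
    decay, no_decay = [], []
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            decay.append(m.weight)
            if m.bias is not None:
                no_decay.append(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            no_decay.extend(m.parameters())
    return decay, no_decay

# usage: one decayed weight, three undecayed params (conv bias + BN affine)
_decay, _no_decay = group_weight_sketch(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
print(len(_decay), len(_no_decay))  # 1 3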