class Session:
    """Bookkeeping for one training run: tensorboard writer, clock, best metric.

    This variant checkpoints the *entire* network object (pickled module),
    not just its state_dict.
    """

    def __init__(self, config, net=None):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.best_val_acc = 0.0
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.clock = TrainClock()

    def save_checkpoint(self, name):
        """Serialize the full network, best accuracy and clock to <model_dir>/<name>."""
        payload = {
            'net': self.net,
            'best_val_acc': self.best_val_acc,
            'clock': self.clock.make_checkpoint(),
        }
        torch.save(payload, os.path.join(self.model_dir, name))

    def load_checkpoint(self, ckp_path):
        """Restore the network object, clock state and best accuracy from a file."""
        checkpoint = torch.load(ckp_path)
        self.net = checkpoint['net']
        self.clock.restore_checkpoint(checkpoint['clock'])
        self.best_val_acc = checkpoint['best_val_acc']
class Session:
    """Bookkeeping for one training run; checkpoints only the network's state_dict.

    Tracks the best validation loss (initialized to +inf so any first
    measurement improves on it).
    """

    def __init__(self, config, net=None):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.best_val_loss = np.inf
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.clock = TrainClock()

    def save_checkpoint(self, name):
        """Write the network's state_dict, best loss and clock to <model_dir>/<name>."""
        payload = {
            'state_dict': self.net.state_dict(),
            'best_val_loss': self.best_val_loss,
            'clock': self.clock.make_checkpoint(),
        }
        torch.save(payload, os.path.join(self.model_dir, name))

    def load_checkpoint(self, ckp_path):
        """Load weights into the existing network and restore loss/clock bookkeeping."""
        checkpoint = torch.load(ckp_path)
        self.net.load_state_dict(checkpoint['state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])
        self.best_val_loss = checkpoint['best_val_loss']
class BaseAgent(object):
    """Base trainer that provides common training behavior.

    All trainers should be subclasses of this class and must implement
    ``build_net`` and ``forward``.  The constructor wires up the network,
    loss, optimizer, scheduler and per-split tensorboard writers from
    ``config``.
    """

    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.device = config.device
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net()
        # set loss function
        self.set_loss_function()
        # set optimizer
        self.set_optimizer(config)

        # separate event files so train/val curves overlay cleanly in tensorboard
        self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
        self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))

    @abstractmethod
    def build_net(self):
        """Construct and return the network; must be provided by subclasses."""
        raise NotImplementedError

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.MSELoss().to(self.device)

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.lr_decay)

    def _checkpoint_payload(self):
        """Collect everything needed to resume training into one dict.

        Unwraps ``nn.DataParallel`` so the saved keys carry no ``module.``
        prefix, and moves the model to CPU so the checkpoint loads on
        machines without a GPU; ``save_ckpt`` moves it back afterwards.
        """
        model = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        return {
            'clock': self.clock.make_checkpoint(),
            'model_state_dict': model.cpu().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
            print("Checkpoint saved at {}".format(save_path))
        else:
            save_path = os.path.join(self.model_dir, "{}.pth".format(name))
        torch.save(self._checkpoint_payload(), save_path)
        # _checkpoint_payload moved the parameters to CPU; restore them to GPU
        self.net.cuda()

    def load_ckpt(self, name=None):
        """load checkpoint from saved checkpoint"""
        name = name if name == 'latest' else "ckpt_epoch{}".format(name)
        load_path = os.path.join(self.model_dir, "{}.pth".format(name))
        if not os.path.exists(load_path):
            raise ValueError("Checkpoint {} not exists.".format(load_path))
        checkpoint = torch.load(load_path)
        print("Checkpoint loaded from {}".format(load_path))
        model = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    @abstractmethod
    def forward(self, data):
        """Run one forward pass; return (outputs, dict of named loss tensors)."""
        pass

    def update_network(self, loss_dict):
        """update network by back propagation"""
        # losses are summed so multi-term objectives backprop in one pass
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """record and update learning rate"""
        self.train_tb.add_scalar('learning_rate', self.optimizer.param_groups[-1]['lr'], self.clock.epoch)
        # NOTE(review): passing an epoch to step() is deprecated in recent
        # torch versions — kept for behavioral compatibility; confirm before removing.
        self.scheduler.step(self.clock.epoch)

    def record_losses(self, loss_dict, mode='train'):
        """Write each named loss scalar to the train or val tensorboard writer."""
        losses_values = {k: v.item() for k, v in loss_dict.items()}
        # record loss to tensorboard
        tb = self.train_tb if mode == 'train' else self.val_tb
        for k, v in losses_values.items():
            tb.add_scalar(k, v, self.clock.step)

    def train_func(self, data):
        """one step of training"""
        self.net.train()
        outputs, losses = self.forward(data)
        self.update_network(losses)
        self.record_losses(losses, 'train')
        return outputs, losses

    def val_func(self, data):
        """one step of validation"""
        self.net.eval()
        with torch.no_grad():
            outputs, losses = self.forward(data)
            self.record_losses(losses, 'validation')
        return outputs, losses

    def visualize_batch(self, data, mode, **kwargs):
        """write visualization results to tensorboard writer"""
        raise NotImplementedError
class VGGAgent(object):
    """Trainer for a VGG classifier.

    Unlike the abstract base trainer this class is concrete: it builds a
    VGG network via ``get_network``, trains with cross-entropy, and keeps
    no tensorboard writers.
    """

    def __init__(self, config):
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config)
        # set loss function
        self.set_loss_function()
        # set optimizer
        self.set_optimizer(config)

    def build_net(self, config):
        """Instantiate the VGG architecture on the GPU."""
        return get_network("VGG", config).cuda()

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.CrossEntropyLoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(
            self.optimizer, config.lr_decay)

    def _checkpoint_payload(self):
        """Collect everything needed to resume training into one dict.

        Unwraps ``nn.DataParallel`` so the saved keys carry no ``module.``
        prefix, and moves the model to CPU so the checkpoint loads on
        machines without a GPU; ``save_ckpt`` moves it back afterwards.
        """
        model = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        return {
            'clock': self.clock.make_checkpoint(),
            'model_state_dict': model.cpu().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(
                self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
            print("Checkpoint saved at {}".format(save_path))
        else:
            save_path = os.path.join(self.model_dir, "{}.pth".format(name))
        torch.save(self._checkpoint_payload(), save_path)
        # _checkpoint_payload moved the parameters to CPU; restore them to GPU
        self.net.cuda()

    def load_ckpt(self, name=None):
        """load checkpoint from saved checkpoint"""
        name = name if name == 'latest' else "ckpt_epoch{}".format(name)
        load_path = os.path.join(self.model_dir, "{}.pth".format(name))
        if not os.path.exists(load_path):
            raise ValueError("Checkpoint {} not exists.".format(load_path))
        checkpoint = torch.load(load_path)
        print("Checkpoint loaded from {}".format(load_path))
        model = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    def forward(self, data, label):
        """Classify ``data`` and score it against ``label``.

        Returns the raw network output and a dict holding the
        cross-entropy loss under the key ``"loss"``.
        """
        output = self.net(data)
        losses = self.criterion(output, label)
        return output, {"loss": losses}

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """record and update learning rate"""
        # NOTE(review): passing an epoch to step() is deprecated in recent
        # torch versions — kept for behavioral compatibility; confirm before removing.
        self.scheduler.step(self.clock.epoch)

    def train_func(self, data, label):
        """one step of training"""
        self.net.train()
        outputs, losses = self.forward(data, label)
        self.update_network(losses)
        return outputs, losses

    def val_func(self, data, label):
        """one step of validation"""
        self.net.eval()
        with torch.no_grad():
            outputs, losses = self.forward(data, label)
        return outputs, losses
class BaseAgent(object):
    """Base trainer that provides common training behavior.

    All trainers should be subclasses of this class and must implement
    ``build_net(config)`` and ``forward``.  Compared to other agents in
    this file, this variant uses a ``StepLR`` scheduler and resumes from
    ``latest.pth.tar`` by default.
    """

    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.device = config.device
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config).cuda()
        # set loss function
        self.set_loss_function()
        # set optimizer
        self.set_optimizer(config)

    @abstractmethod
    def build_net(self, config):
        """Construct and return the network; must be provided by subclasses."""
        raise NotImplementedError

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.MSELoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, config.lr_step_size)

    def _checkpoint_payload(self):
        """Collect everything needed to resume training into one dict.

        Unwraps ``nn.DataParallel`` so the saved keys carry no ``module.``
        prefix, and moves the model to CPU so the checkpoint loads on
        machines without a GPU; ``save_ckpt`` moves it back afterwards.
        """
        model = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        return {
            'clock': self.clock.make_checkpoint(),
            'model_state_dict': model.cpu().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(
                self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
        else:
            # NOTE(review): unlike the epoch path, ``name`` is used verbatim —
            # callers must include the file extension themselves.
            save_path = os.path.join(self.model_dir, name)
        torch.save(self._checkpoint_payload(), save_path)
        # _checkpoint_payload moved the parameters to CPU; restore them to GPU
        self.net.cuda()

    def load_ckpt(self, path=None):
        """load checkpoint from saved checkpoint"""
        # default to the rolling "latest" checkpoint when no explicit path given
        load_path = path if path is not None else os.path.join(self.model_dir, "latest.pth.tar")
        checkpoint = torch.load(load_path)
        model = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    @abstractmethod
    def forward(self, data):
        """Run one forward pass; return (outputs, dict of named loss tensors)."""
        pass

    def update_network(self, loss_dict):
        """update network by back propagation"""
        # losses are summed so multi-term objectives backprop in one pass
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """Advance the lr scheduler to the current epoch."""
        # NOTE(review): passing an epoch to step() is deprecated in recent
        # torch versions — kept for behavioral compatibility; confirm before removing.
        self.scheduler.step(self.clock.epoch)

    def train_func(self, data):
        """one step of training"""
        self.net.train()
        outputs, losses = self.forward(data)
        self.update_network(losses)
        return outputs, losses

    def val_func(self, data):
        """one step of validation"""
        self.net.eval()
        with torch.no_grad():
            outputs, losses = self.forward(data)
        return outputs, losses

    def visualize_batch(self, data, tb, **kwargs):
        """write visualization results to tensorboard writer"""
        raise NotImplementedError