def __init__(self, config, net=None):
    self.log_dir = config.log_dir
    self.model_dir = config.model_dir
    self.net = net
    self.best_val_loss = np.inf
    self.tb_writer = SummaryWriter(log_dir=self.log_dir)
    self.clock = TrainClock()
def __init__(self, config):
    self.model_dir = config.model_dir
    self.clock = TrainClock()
    self.batch_size = config.batch_size

    # build network
    self.net = self.build_net(config)

    # set loss function
    self.set_loss_function()

    # set optimizer
    self.set_optimizer(config)
def __init__(self, config):
    self.log_dir = config.log_dir
    self.model_dir = config.model_dir
    self.clock = TrainClock()
    self.device = config.device
    self.batch_size = config.batch_size

    # build network
    self.net = self.build_net()

    # set loss function
    self.set_loss_function()

    # set optimizer
    self.set_optimizer(config)

    # set tensorboard writer
    self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
    self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))
def __init__(self, train_spec, net=None):
    # setproctitle(config.exp_name)
    self.log_dir = train_spec.log_dir
    ensure_dir(self.log_dir)
    # logconf.set_output_file(os.path.join(self.log_dir, 'log.txt'))
    self.model_dir = train_spec.log_model_dir
    ensure_dir(self.model_dir)
    self.net = net
    self.clock = TrainClock()
def __init__(self, config):
    self.log_dir = config.log_dir
    self.model_dir = config.model_dir
    self.clock = TrainClock()
    self.batch_size = config.batch_size

    # build network
    self.net = self.build_net(config)
    # print('-----network architecture-----')
    # print(self.net)

    # set loss function
    self.set_loss_function()

    # set optimizer and scheduler
    self.set_optimizer(config)
    self.set_scheduler(config)

    # set tensorboard writer
    self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
    self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))
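# The `set_scheduler` hook called above is not defined anywhere in these
# snippets. A minimal sketch of what it could look like, mirroring the
# optimizer/scheduler pairing used by the agents below; the `config.lr_decay`
# field and the ExponentialLR choice are assumptions, not the original code:
def set_scheduler(self, config):
    """set lr scheduler used in training (hypothetical implementation)"""
    self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.lr_decay)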
class Session:
    def __init__(self, config, net=None):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.best_val_loss = np.inf
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.clock = TrainClock()

    def save_checkpoint(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        tmp = {
            'state_dict': self.net.state_dict(),
            'best_val_loss': self.best_val_loss,
            'clock': self.clock.make_checkpoint(),
        }
        torch.save(tmp, ckp_path)

    def load_checkpoint(self, ckp_path):
        checkpoint = torch.load(ckp_path)
        self.net.load_state_dict(checkpoint['state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])
        self.best_val_loss = checkpoint['best_val_loss']
class Session:
    def __init__(self, config, net=None):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.net = net
        self.best_val_acc = 0.0
        self.tb_writer = SummaryWriter(log_dir=self.log_dir)
        self.clock = TrainClock()

    def save_checkpoint(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        # unlike the variant above, this pickles the whole module rather than
        # its state_dict
        tmp = {
            'net': self.net,
            'best_val_acc': self.best_val_acc,
            'clock': self.clock.make_checkpoint(),
        }
        torch.save(tmp, ckp_path)

    def load_checkpoint(self, ckp_path):
        checkpoint = torch.load(ckp_path)
        self.net = checkpoint['net']
        self.clock.restore_checkpoint(checkpoint['clock'])
        self.best_val_acc = checkpoint['best_val_acc']
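# Usage sketch (not from the original source): a save/load round trip for the
# two Session variants above; `config` and `net` are placeholders defined
# elsewhere. The trade-off between the variants: storing only the state_dict
# needs the network class available at load time but survives refactors, while
# pickling the whole module is simpler to restore but breaks if the class
# definition moves or changes.
session = Session(config, net=net)
session.save_checkpoint('latest.pth')
session.load_checkpoint(os.path.join(config.model_dir, 'latest.pth'))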
parser = argparse.ArgumentParser()
parser.add_argument('--val', action='store_true', help='whether to run validation')
parser.add_argument('--bench', action='store_true', help='whether to generate results on the benchmark dataset')
parser.add_argument('--smooth', action='store_true', help='whether to generate SmoothGrad results on the benchmark dataset')
parser.add_argument('--gen', action='store_true', help='whether to generate large-eps adversarial examples on the benchmark data')
parser.add_argument('--resume', type=str, default=None, help='checkpoint path')
args = parser.parse_args()

clock = TrainClock()
clock.epoch = 21

net = ant_model()
net.cuda()
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        check_point = torch.load(args.resume)
        net.load_state_dict(check_point['state_dict'])
        print('Model loaded from {} with metrics:'.format(args.resume))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
    base_path = os.path.split(args.resume)[0]
else:
class BaseAgent(object):
    """Base trainer that provides common training behavior.

    All trainers should be subclasses of this class.
    """

    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.device = config.device
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net()

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)

        # set tensorboard writer
        self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
        self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))

    @abstractmethod
    def build_net(self):
        raise NotImplementedError

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.MSELoss().to(self.device)

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.lr_decay)

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, "{}.pth".format(name))
        print("Checkpoint saved at {}".format(save_path))

        # move weights to CPU before serializing, then restore to GPU
        if isinstance(self.net, nn.DataParallel):
            model_state_dict = self.net.module.cpu().state_dict()
        else:
            model_state_dict = self.net.cpu().state_dict()
        torch.save({
            'clock': self.clock.make_checkpoint(),
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, save_path)
        self.net.cuda()

    def load_ckpt(self, name=None):
        """load checkpoint from saved checkpoint"""
        name = name if name == 'latest' else "ckpt_epoch{}".format(name)
        load_path = os.path.join(self.model_dir, "{}.pth".format(name))
        if not os.path.exists(load_path):
            raise ValueError("Checkpoint {} does not exist.".format(load_path))
        checkpoint = torch.load(load_path)
        print("Checkpoint loaded from {}".format(load_path))
        if isinstance(self.net, nn.DataParallel):
            self.net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    @abstractmethod
    def forward(self, data):
        pass

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """record and update learning rate"""
        self.train_tb.add_scalar('learning_rate', self.optimizer.param_groups[-1]['lr'], self.clock.epoch)
        self.scheduler.step(self.clock.epoch)

    def record_losses(self, loss_dict, mode='train'):
        losses_values = {k: v.item() for k, v in loss_dict.items()}
        # record loss to tensorboard
        tb = self.train_tb if mode == 'train' else self.val_tb
        for k, v in losses_values.items():
            tb.add_scalar(k, v, self.clock.step)

    def train_func(self, data):
        """one step of training"""
        self.net.train()
        outputs, losses = self.forward(data)
        self.update_network(losses)
        self.record_losses(losses, 'train')
        return outputs, losses

    def val_func(self, data):
        """one step of validation"""
        self.net.eval()
        with torch.no_grad():
            outputs, losses = self.forward(data)
        self.record_losses(losses, 'validation')
        return outputs, losses

    def visualize_batch(self, data, mode, **kwargs):
        """write visualization results to tensorboard writer"""
        raise NotImplementedError
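# A hypothetical concrete agent showing the two @abstractmethod hooks a
# subclass of BaseAgent must implement. `MyNet` and the data layout (a dict
# with 'input' and 'target' tensors) are assumptions for illustration only.
class ExampleAgent(BaseAgent):
    def build_net(self):
        return MyNet().to(self.device)

    def forward(self, data):
        inputs = data['input'].to(self.device)
        targets = data['target'].to(self.device)
        outputs = self.net(inputs)
        # update_network() sums this dict, so each value must be a scalar tensor
        losses = {'mse_loss': self.criterion(outputs, targets)}
        return outputs, losses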
class VGGAgent(object):
    """Trainer for a VGG classifier; follows the common agent training behavior."""

    def __init__(self, config):
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config)

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)

    def build_net(self, config):
        return get_network("VGG", config).cuda()

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.CrossEntropyLoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.lr_decay)

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, "{}.pth".format(name))
        print("Checkpoint saved at {}".format(save_path))

        # move weights to CPU before serializing, then restore to GPU
        if isinstance(self.net, nn.DataParallel):
            model_state_dict = self.net.module.cpu().state_dict()
        else:
            model_state_dict = self.net.cpu().state_dict()
        torch.save({
            'clock': self.clock.make_checkpoint(),
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, save_path)
        self.net.cuda()

    def load_ckpt(self, name=None):
        """load checkpoint from saved checkpoint"""
        name = name if name == 'latest' else "ckpt_epoch{}".format(name)
        load_path = os.path.join(self.model_dir, "{}.pth".format(name))
        if not os.path.exists(load_path):
            raise ValueError("Checkpoint {} does not exist.".format(load_path))
        checkpoint = torch.load(load_path)
        print("Checkpoint loaded from {}".format(load_path))
        if isinstance(self.net, nn.DataParallel):
            self.net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    def forward(self, data, label):
        output = self.net(data)
        losses = self.criterion(output, label)
        return output, {"loss": losses}

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        """record and update learning rate"""
        self.scheduler.step(self.clock.epoch)

    def train_func(self, data, label):
        """one step of training"""
        self.net.train()
        outputs, losses = self.forward(data, label)
        self.update_network(losses)
        return outputs, losses

    def val_func(self, data, label):
        """one step of validation"""
        self.net.eval()
        with torch.no_grad():
            outputs, losses = self.forward(data, label)
        return outputs, losses
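# Training-loop sketch (assumed, not from the original source) driving
# VGGAgent; `train_loader` and the `config.nr_epochs` field are placeholders.
agent = VGGAgent(config)
for epoch in range(config.nr_epochs):
    for data, label in train_loader:
        outputs, losses = agent.train_func(data.cuda(), label.cuda())
        agent.clock.tick()
    agent.update_learning_rate()
    agent.save_ckpt()  # writes ckpt_epoch{N}.pth
    agent.clock.tock()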
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--continue', dest='continue_path', type=str, required=False)
    parser.add_argument('-g', '--gpu_ids', type=int, default=0, required=False)
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)
    config.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    config.isTrain = True

    if not os.path.exists('train_log'):
        os.symlink(config.exp_dir, 'train_log')

    # get dataset
    train_loader = get_dataloader("train", batch_size=config.batch_size)
    val_loader = get_dataloader("test", batch_size=config.batch_size)
    val_cycle = cycle(val_loader)
    dataset_size = len(train_loader)
    print('The number of training motions = %d' % (dataset_size * config.batch_size))

    # create tensorboard writer
    train_tb = SummaryWriter(os.path.join(config.log_dir, 'train.events'))
    val_tb = SummaryWriter(os.path.join(config.log_dir, 'val.events'))

    # get model
    net = CycleGANModel(config)
    net.print_networks(True)

    # start training
    clock = TrainClock()
    net.train()
    for e in range(config.nr_epochs):
        # begin iteration
        pbar = tqdm(train_loader)
        for b, data in enumerate(pbar):
            net.train()
            # unpack data from dataset and apply preprocessing
            net.set_input(data)
            # calculate loss functions, get gradients, update network weights
            net.optimize_parameters()

            # get loss
            losses_values = net.get_current_losses()

            # update tensorboard
            train_tb.add_scalars('train_loss', losses_values, global_step=clock.step)

            # visualize
            if clock.step % config.visualize_frequency == 0:
                motion_dict = net.infer()
                for k, v in motion_dict.items():
                    phase = 'h' if k[-1] == 'A' else 'nh'
                    motion3d = train_loader.dataset.preprocess_inv(v.detach().cpu().numpy()[0], phase)
                    img = plot_motion(motion3d, phase)
                    train_tb.add_image(k, img, global_step=clock.step)

            pbar.set_description("EPOCH[{}][{}/{}]".format(e, b, len(train_loader)))
            pbar.set_postfix(OrderedDict(losses_values))

            # validation
            if clock.step % config.val_frequency == 0:
                net.eval()
                data = next(val_cycle)
                net.set_input(data)
                net.forward()

                losses_values = net.get_current_losses()
                val_tb.add_scalars('val_loss', losses_values, global_step=clock.step)

                # visualize
                if clock.step % config.visualize_frequency == 0:
                    motion_dict = net.infer()
                    for k, v in motion_dict.items():
                        phase = 'h' if k[-1] == 'A' else 'nh'
                        motion3d = val_loader.dataset.preprocess_inv(v.detach().cpu().numpy()[0], phase)
                        img = plot_motion(motion3d, phase)
                        val_tb.add_image(k, img, global_step=clock.step)

            clock.tick()

        # record learning rate to tensorboard
        lr = net.optimizers[0].param_groups[0]['lr']
        train_tb.add_scalar("learning_rate", lr, global_step=clock.step)

        if clock.epoch % config.save_frequency == 0:
            net.save_networks(epoch=e)

        clock.tock()
        net.update_learning_rate()  # update learning rates at the end of every epoch
DEVICE = torch.device('cuda:{}'.format(args.d))

if args.exp is None:
    cur_dir = os.path.realpath('./')
    args.exp = cur_dir.split(os.path.sep)[-1]

log_dir = os.path.join('../../logs', args.exp)
exp_dir = os.path.join('../../exps', args.exp)
train_res_path = os.path.join(exp_dir, 'train_results.txt')
val_res_path = os.path.join(exp_dir, 'val_results.txt')
final_res_path = os.path.join(exp_dir, 'final_results.txt')
if not os.path.exists(exp_dir):
    os.mkdir(exp_dir)
save_args(args, exp_dir)

writer = SummaryWriter(log_dir)
clock = TrainClock()

learning_rate_policy = [[5, 0.01], [3, 0.001], [2, 0.0001]]
get_learning_rate = MultiStageLearningRatePolicy(learning_rate_policy)


def adjust_learning_rate(optimizer, epoch):
    """set the learning rate of every param group according to the policy"""
    lr = get_learning_rate(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


torch.backends.cudnn.benchmark = True

ds_train = create_train_dataset(args.batch_size)
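# MultiStageLearningRatePolicy is not defined in these snippets. A hypothetical
# implementation consistent with how it is constructed and called above,
# assuming each [n, lr] pair means "use lr for the next n epochs" and the last
# stage's lr is kept for any epochs beyond the schedule:
class MultiStageLearningRatePolicy:
    def __init__(self, stages):
        self.stages = stages  # list of [num_epochs, lr] pairs

    def __call__(self, epoch):
        boundary = 0
        for num_epochs, lr in self.stages:
            boundary += num_epochs
            if epoch < boundary:
                return lr
        return self.stages[-1][1]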
class BaseAgent(object):
    """Base trainer that provides common training behavior.

    All trainers should be subclasses of this class.
    """

    def __init__(self, config):
        self.log_dir = config.log_dir
        self.model_dir = config.model_dir
        self.clock = TrainClock()
        self.device = config.device
        self.batch_size = config.batch_size

        # build network
        self.net = self.build_net(config).cuda()

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(config)

    @abstractmethod
    def build_net(self, config):
        raise NotImplementedError

    def set_loss_function(self):
        """set loss function used in training"""
        self.criterion = nn.MSELoss().cuda()

    def set_optimizer(self, config):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), config.lr)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, config.lr_step_size)

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, name)

        # move weights to CPU before serializing, then restore to GPU
        if isinstance(self.net, nn.DataParallel):
            model_state_dict = self.net.module.cpu().state_dict()
        else:
            model_state_dict = self.net.cpu().state_dict()
        torch.save({
            'clock': self.clock.make_checkpoint(),
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, save_path)
        self.net.cuda()

    def load_ckpt(self, path=None):
        """load checkpoint from saved checkpoint"""
        if path is not None:
            load_path = path
        else:
            load_path = os.path.join(self.model_dir, "latest.pth.tar")
        checkpoint = torch.load(load_path)
        if isinstance(self.net, nn.DataParallel):
            self.net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    @abstractmethod
    def forward(self, data):
        pass

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_learning_rate(self):
        self.scheduler.step(self.clock.epoch)

    def train_func(self, data):
        """one step of training"""
        self.net.train()
        outputs, losses = self.forward(data)
        self.update_network(losses)
        return outputs, losses

    def val_func(self, data):
        """one step of validation"""
        self.net.eval()
        with torch.no_grad():
            outputs, losses = self.forward(data)
        return outputs, losses

    def visualize_batch(self, data, tb, **kwargs):
        """write visualization results to tensorboard writer"""
        raise NotImplementedError
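# TrainClock is used throughout these snippets but never defined. A minimal
# sketch consistent with that usage (step/epoch counters, tick() per batch,
# tock() per epoch, dict round trip for checkpointing); the real class may
# track additional state:
class TrainClock(object):
    def __init__(self):
        self.epoch = 0
        self.minibatch = 0
        self.step = 0

    def tick(self):
        # advance one training iteration
        self.minibatch += 1
        self.step += 1

    def tock(self):
        # advance one epoch
        self.epoch += 1
        self.minibatch = 0

    def make_checkpoint(self):
        return {'epoch': self.epoch, 'minibatch': self.minibatch, 'step': self.step}

    def restore_checkpoint(self, clock_dict):
        self.epoch = clock_dict['epoch']
        self.minibatch = clock_dict['minibatch']
        self.step = clock_dict['step']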