def train(config):
    """Train the bird's-eye-view policy network and periodically checkpoint it.

    Args:
        config: dict providing at least 'log_dir', 'data_args', 'model_args'
            (with 'backbone'), 'device', 'resume', 'optimizer_args' (with
            'lr') and 'max_epoch'.

    Each epoch runs one training pass over the training split followed by an
    evaluation pass over the validation split. Whenever the epoch is in
    SAVE_EPOCHS, the weights are written to ``<log_dir>/model-<epoch>.th``.

    Raises:
        FileNotFoundError: if ``config['resume']`` is set but ``log_dir``
            contains no ``model-<epoch>.th`` checkpoint.
    """
    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)

    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(w=192, h=192, choice='l1')
    net = BirdViewPolicyModelSS(config['model_args']['backbone']).to(config['device'])

    if config['resume']:
        log_dir = Path(config['log_dir'])

        def _epoch_of(path):
            # 'model-12.th' -> 12; None for names without a numeric epoch.
            suffix = path.stem.partition('-')[2]
            return int(suffix) if suffix.isdigit() else None

        # Path.glob yields files in arbitrary (filesystem) order, and a plain
        # string sort would rank 'model-10' below 'model-2', so choose the
        # checkpoint with the highest numeric epoch explicitly.
        checkpoints = [p for p in log_dir.glob('model-*.th') if _epoch_of(p) is not None]
        if not checkpoints:
            raise FileNotFoundError('no model-<epoch>.th checkpoint found in %s' % log_dir)
        checkpoint = str(max(checkpoints, key=_epoch_of))
        print("load %s" % checkpoint)
        # map_location lets a checkpoint saved on GPU load on the configured device.
        net.load_state_dict(torch.load(checkpoint, map_location=config['device']))
        # NOTE(review): only weights are restored; the optimizer below starts
        # fresh and the epoch counter restarts at 0, so later saves overwrite
        # earlier model-<epoch>.th files.

    optim = torch.optim.Adam(net.parameters(), lr=config['optimizer_args']['lr'])
    log_dir_path = Path(config['log_dir'])

    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # The last argument flags the first epoch for train_or_eval.
        train_or_eval(criterion, net, data_train, optim, True, config, epoch == 0)
        train_or_eval(criterion, net, data_val, None, False, config, epoch == 0)

        if epoch in SAVE_EPOCHS:
            torch.save(net.state_dict(), str(log_dir_path / ('model-%d.th' % epoch)))

        bzu.log.end_epoch()
def train(config):
    """Train the bird's-eye-view policy network and periodically checkpoint it.

    Args:
        config: dict providing at least 'log_dir', 'data_args', 'model_args'
            (with 'backbone'), 'device', 'resume', 'optimizer_args' (with
            'lr') and 'max_epoch'.

    Each epoch runs one training pass over the training split followed by an
    evaluation pass over the validation split. Whenever the epoch is in
    SAVE_EPOCHS, the weights are written to ``<log_dir>/model-<epoch>.th``.

    Raises:
        FileNotFoundError: if ``config['resume']`` is set but ``log_dir``
            contains no ``model-<epoch>.th`` checkpoint.
    """
    # bzu.log is a saver.Experiment instance
    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)

    # For reference, load_data is imported elsewhere as:
    #   from utils.datasets.birdview_lmdb import get_birdview as load_data
    # Per the original author, the data consists of frames, each holding
    # (birdview, location, command, speed).
    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(w=192, h=192, choice='l1')
    net = BirdViewPolicyModelSS(config['model_args']['backbone']).to(config['device'])

    # Load the most recent existing checkpoint only if --resume is specified.
    if config['resume']:
        log_dir = Path(config['log_dir'])

        def _epoch_of(path):
            # 'model-12.th' -> 12; None for names without a numeric epoch.
            suffix = path.stem.partition('-')[2]
            return int(suffix) if suffix.isdigit() else None

        # Path.glob yields files in arbitrary (filesystem) order, and a plain
        # string sort would rank 'model-10' below 'model-2', so choose the
        # checkpoint with the highest numeric epoch explicitly.
        checkpoints = [p for p in log_dir.glob('model-*.th') if _epoch_of(p) is not None]
        if not checkpoints:
            raise FileNotFoundError('no model-<epoch>.th checkpoint found in %s' % log_dir)
        checkpoint = str(max(checkpoints, key=_epoch_of))
        print("load %s" % checkpoint)
        # map_location lets a checkpoint saved on GPU load on the configured device.
        net.load_state_dict(torch.load(checkpoint, map_location=config['device']))
        # NOTE(review): only weights are restored; the optimizer below starts
        # fresh and the epoch counter restarts at 0, so later saves overwrite
        # earlier model-<epoch>.th files.

    optim = torch.optim.Adam(net.parameters(), lr=config['optimizer_args']['lr'])
    log_dir_path = Path(config['log_dir'])

    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # train
        train_or_eval(criterion, net, data_train, optim, True, config, epoch == 0)
        # evaluate
        train_or_eval(criterion, net, data_val, None, False, config, epoch == 0)

        if epoch in SAVE_EPOCHS:
            torch.save(net.state_dict(), str(log_dir_path / ('model-%d.th' % epoch)))

        bzu.log.end_epoch()
def train(config):
    """Train the image policy network using a pretrained bird's-eye-view network as teacher.

    Args:
        config: dict providing at least 'log_dir', 'data_args', 'camera_args',
            'model_args' (with 'backbone' and 'imagenet_pretrained'),
            'teacher_args' (with 'model_path'), 'device', 'optimizer_args'
            (with 'lr') and 'max_epoch'.

    The teacher's weights are loaded from ``teacher_args['model_path']`` and
    the teacher is put in eval mode. Only the student's parameters are given
    to the optimizer. Each epoch runs a training pass and a validation pass.
    Whenever the epoch is in SAVE_EPOCHS, the student's weights are written to
    ``<log_dir>/model-<epoch>.th``.
    """
    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)
    # Presumably reads the config saved alongside the teacher checkpoint;
    # verify against bzu.log.load_config.
    teacher_config = bzu.log.load_config(config['teacher_args']['model_path'])

    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(**config['camera_args'])
    net = ImagePolicyModelSS(
        config['model_args']['backbone'],
        pretrained=config['model_args']['imagenet_pretrained'],
    ).to(config['device'])

    teacher_net = BirdViewPolicyModelSS(
        teacher_config['model_args']['backbone']).to(config['device'])
    # map_location lets a teacher checkpoint saved on GPU load on the
    # configured device (e.g. a CPU-only machine).
    teacher_net.load_state_dict(
        torch.load(config['teacher_args']['model_path'], map_location=config['device']))
    teacher_net.eval()

    coord_converter = CoordConverter(**config['camera_args'])
    optim = torch.optim.Adam(net.parameters(), lr=config['optimizer_args']['lr'])
    log_dir = Path(config['log_dir'])

    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # Training pass (with optimizer), then evaluation pass (no optimizer).
        train_or_eval(coord_converter, criterion, net, teacher_net, data_train, optim,
                      True, config, epoch == 0)
        train_or_eval(coord_converter, criterion, net, teacher_net, data_val, None,
                      False, config, epoch == 0)

        if epoch in SAVE_EPOCHS:
            torch.save(net.state_dict(), str(log_dir / ('model-%d.th' % epoch)))

        bzu.log.end_epoch()
def load_birdview_model(backbone, ckpt, device='cuda'):
    """Build a BirdViewPolicyModelSS with all branches and load its weights.

    Args:
        backbone: backbone identifier forwarded to BirdViewPolicyModelSS.
        ckpt: path to a saved state_dict.
        device: device to place the model on and to map checkpoint tensors to.

    Returns:
        The loaded model. It is not switched to eval mode here, so callers
        doing inference should call ``.eval()`` themselves.
    """
    teacher_net = BirdViewPolicyModelSS(backbone, all_branch=True).to(device)
    # Map stored tensors onto `device` so a GPU-saved checkpoint can also be
    # loaded when device is 'cpu'.
    teacher_net.load_state_dict(torch.load(ckpt, map_location=device))
    return teacher_net