def train(config):
    """Train an ``ImagePolicyModelSS`` student alongside a fixed teacher model.

    The teacher is a ``BirdViewPolicyModelSS`` restored from a checkpoint and put
    in eval mode. Only the student's parameters are given to the optimizer, so
    the teacher's weights are not updated by this optimizer.

    Each epoch runs one training pass over ``data_train`` and one evaluation pass
    over ``data_val`` via ``train_or_eval``. Student checkpoints are written for
    every epoch listed in ``SAVE_EPOCHS``.

    Args:
        config: Nested dict of settings. Keys read here:
            ``log_dir``, ``device``, ``max_epoch``,
            ``teacher_args.model_path``, ``data_args`` (kwargs for ``load_data``),
            ``camera_args`` (kwargs for ``LocationLoss`` and ``CoordConverter``),
            ``model_args.backbone``, ``model_args.imagenet_pretrained``,
            ``optimizer_args.lr``.
    """
    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)

    device = config['device']
    teacher_path = config['teacher_args']['model_path']

    # NOTE(review): load_config receives the checkpoint path itself; presumably it
    # locates the teacher's saved training config from there — confirm.
    teacher_config = bzu.log.load_config(teacher_path)

    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(**config['camera_args'])

    net = ImagePolicyModelSS(
        config['model_args']['backbone'],
        pretrained=config['model_args']['imagenet_pretrained'],
    ).to(device)

    teacher_net = BirdViewPolicyModelSS(
        teacher_config['model_args']['backbone']).to(device)
    # map_location remaps stored tensors onto ``device``, so a checkpoint saved on
    # a device that is unavailable now (e.g. a GPU when training on CPU) still loads.
    teacher_net.load_state_dict(torch.load(teacher_path, map_location=device))
    teacher_net.eval()

    coord_converter = CoordConverter(**config['camera_args'])

    # Only the student is optimized; the teacher stays fixed.
    optim = torch.optim.Adam(net.parameters(), lr=config['optimizer_args']['lr'])

    # range(max_epoch + 1) runs epochs 0..max_epoch inclusive.
    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # NOTE(review): the trailing flag is True only on the first epoch; its exact
        # meaning is defined by train_or_eval — confirm.
        is_first_epoch = epoch == 0

        train_or_eval(coord_converter, criterion, net, teacher_net,
                      data_train, optim, True, config, is_first_epoch)
        train_or_eval(coord_converter, criterion, net, teacher_net,
                      data_val, None, False, config, is_first_epoch)

        if epoch in SAVE_EPOCHS:
            torch.save(
                net.state_dict(),
                str(Path(config['log_dir']) / f'model-{epoch}.th'))

        bzu.log.end_epoch()
def load_image_model(backbone, ckpt, device='cuda'):
    """Construct an all-branch ``ImagePolicyModelSS`` and restore its weights.

    Args:
        backbone: Backbone identifier forwarded to ``ImagePolicyModelSS``.
        ckpt: Path (or file-like object) accepted by ``torch.load`` that holds the
            model's state dict.
        device: Device the model is moved to. It is also used as
            ``map_location``, so the stored tensors are loaded onto it directly.

    Returns:
        The model on ``device`` with the checkpoint's parameters loaded.
        NOTE(review): ``.eval()`` is not called here; callers presumably handle
        the train/eval mode — confirm.
    """
    model = ImagePolicyModelSS(backbone, all_branch=True)
    model = model.to(device)

    weights = torch.load(ckpt, map_location=device)
    model.load_state_dict(weights)

    return model