def train(config):
    """Train the image policy ``ImagePolicyModelSS`` using a checkpointed
    ``BirdViewPolicyModelSS`` teacher.

    Runs ``config['max_epoch'] + 1`` epochs, numbered ``0 .. max_epoch``.
    Each epoch:

    1. does one training pass over the training split (with the Adam
       optimizer);
    2. does one evaluation pass over the validation split (no optimizer);
    3. saves the student's ``state_dict`` to ``<log_dir>/model-<epoch>.th``
       when the epoch number is in ``SAVE_EPOCHS``;
    4. calls ``bzu.log.end_epoch()``.

    Both passes go through ``train_or_eval``.

    Args:
        config: Nested settings dict. Keys read here:

            - ``log_dir``
            - ``device``
            - ``max_epoch``
            - ``teacher_args['model_path']``
            - ``data_args`` (forwarded as kwargs to ``load_data``)
            - ``camera_args`` (forwarded to ``LocationLoss`` and
              ``CoordConverter``)
            - ``model_args['backbone']``
            - ``model_args['imagenet_pretrained']``
            - ``optimizer_args['lr']``
    """
    device = config['device']
    teacher_path = config['teacher_args']['model_path']

    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)
    # NOTE(review): load_config is handed the teacher checkpoint path;
    # presumably it finds the config saved alongside that checkpoint — confirm.
    teacher_config = bzu.log.load_config(teacher_path)

    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(**config['camera_args'])

    # Student network: the only model whose parameters are optimized.
    net = ImagePolicyModelSS(
        config['model_args']['backbone'],
        pretrained=config['model_args']['imagenet_pretrained']).to(device)

    # The teacher uses the backbone recorded in its own config, not the
    # student's.
    teacher_net = BirdViewPolicyModelSS(
        teacher_config['model_args']['backbone']).to(device)
    # map_location puts the checkpoint tensors directly on the target device.
    # Without it, a checkpoint saved on a GPU fails to load on a CPU-only run
    # (or lands on the GPU it was saved from).
    teacher_net.load_state_dict(torch.load(teacher_path, map_location=device))
    # Eval mode for the teacher; it is never handed to the optimizer below.
    teacher_net.eval()

    coord_converter = CoordConverter(**config['camera_args'])

    optim = torch.optim.Adam(net.parameters(),
                             lr=config['optimizer_args']['lr'])

    log_dir = Path(config['log_dir'])

    # +1 so that epoch `max_epoch` itself is run (epochs 0..max_epoch).
    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # NOTE(review): the last argument is True only on epoch 0; presumably
        # train_or_eval treats it as a "first epoch" flag — confirm.
        is_first = epoch == 0
        train_or_eval(coord_converter, criterion, net, teacher_net, data_train,
                      optim, True, config, is_first)
        train_or_eval(coord_converter, criterion, net, teacher_net, data_val,
                      None, False, config, is_first)

        if epoch in SAVE_EPOCHS:
            torch.save(net.state_dict(),
                       str(log_dir / ('model-%d.th' % epoch)))

        bzu.log.end_epoch()
def load_image_model(backbone, ckpt, device='cuda'):
    """Build an ``ImagePolicyModelSS`` and load saved weights into it.

    Args:
        backbone: Backbone identifier forwarded to ``ImagePolicyModelSS``.
        ckpt: Path (or file-like object) accepted by ``torch.load`` that holds
            the model's ``state_dict``.
        device: Device the model is moved to. The checkpoint tensors are
            mapped onto this device while loading.

    Returns:
        The model, constructed with ``all_branch=True``, placed on ``device``,
        with the checkpoint weights loaded.
    """
    model = ImagePolicyModelSS(backbone, all_branch=True)
    model = model.to(device)

    weights = torch.load(ckpt, map_location=device)
    model.load_state_dict(weights)

    return model