Example no. 1
0
def train(config):
    """Train BirdViewPolicyModelSS, running a validation pass and optional checkpoint each epoch.

    Args:
        config: run configuration mapping. Keys read here: 'log_dir',
            'data_args' (forwarded to load_data), 'model_args'['backbone'],
            'device', 'resume', 'optimizer_args'['lr'] and 'max_epoch'.

    Raises:
        FileNotFoundError: if config['resume'] is set but log_dir contains no
            'model-<epoch>.th' checkpoint.
    """
    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)

    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(w=192, h=192, choice='l1')
    net = BirdViewPolicyModelSS(config['model_args']['backbone']).to(config['device'])

    if config['resume']:
        log_dir = Path(config['log_dir'])
        # Path.glob yields files in arbitrary order, and a lexicographic sort
        # would rank 'model-10.th' before 'model-2.th', so pick the checkpoint
        # with the highest numeric epoch explicitly.
        candidates = []
        for path in log_dir.glob('model-*.th'):
            epoch_str = path.stem.partition('-')[2]
            if epoch_str.isdecimal():
                candidates.append((int(epoch_str), path))
        if not candidates:
            raise FileNotFoundError(
                'resume requested but no model-<epoch>.th checkpoint in %s' % log_dir)
        checkpoint = str(max(candidates, key=lambda c: c[0])[1])
        print("load %s" % checkpoint)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be deserialized directly onto the configured device.
        net.load_state_dict(torch.load(checkpoint, map_location=config['device']))

    optim = torch.optim.Adam(net.parameters(), lr=config['optimizer_args']['lr'])

    # NOTE(review): the epoch counter restarts at 0 even when resuming, so a
    # resumed run re-saves (overwrites) model-<epoch>.th files, and after an
    # interrupted resumed run the highest-numbered file may be stale — confirm.
    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # Pass over the training data with the optimizer, then over the
        # validation data without one. The last argument is True only on the
        # first epoch; its meaning is defined by train_or_eval (not shown here).
        train_or_eval(criterion, net, data_train, optim, True, config, epoch == 0)
        train_or_eval(criterion, net, data_val, None, False, config, epoch == 0)

        if epoch in SAVE_EPOCHS:
            torch.save(
                net.state_dict(),
                str(Path(config['log_dir']) / ('model-%d.th' % epoch)))

        bzu.log.end_epoch()
def train(config):
    """Train BirdViewPolicyModelSS, running a validation pass and optional checkpoint each epoch.

    Args:
        config: run configuration mapping. Keys read here: 'log_dir',
            'data_args' (forwarded to load_data), 'model_args'['backbone'],
            'device', 'resume', 'optimizer_args'['lr'] and 'max_epoch'.

    Raises:
        FileNotFoundError: if config['resume'] is set but log_dir contains no
            'model-<epoch>.th' checkpoint.
    """
    # bzu.log is a saver.Experiment instance
    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)

    # load_data is imported elsewhere in the file; the original note names
    # utils.datasets.birdview_lmdb.get_birdview as its source.
    # data contains frames. frames contain (birdview, location, command, speed)
    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(w=192, h=192, choice='l1')
    net = BirdViewPolicyModelSS(config['model_args']['backbone']).to(config['device'])

    # Load the highest-epoch existing checkpoint only if --resume was specified.
    if config['resume']:
        log_dir = Path(config['log_dir'])
        # Path.glob order is arbitrary and a lexicographic sort would rank
        # 'model-10.th' before 'model-2.th', so compare numeric epochs instead.
        candidates = []
        for path in log_dir.glob('model-*.th'):
            epoch_str = path.stem.partition('-')[2]
            if epoch_str.isdecimal():
                candidates.append((int(epoch_str), path))
        if not candidates:
            raise FileNotFoundError(
                'resume requested but no model-<epoch>.th checkpoint in %s' % log_dir)
        checkpoint = str(max(candidates, key=lambda c: c[0])[1])
        print("load %s" % checkpoint)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be deserialized directly onto the configured device.
        net.load_state_dict(torch.load(checkpoint, map_location=config['device']))

    optim = torch.optim.Adam(net.parameters(), lr=config['optimizer_args']['lr'])

    # NOTE(review): epochs restart at 0 on resume, so earlier model-<epoch>.th
    # files get overwritten as the resumed run progresses — confirm intended.
    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # train (the last argument is True only on the first epoch; its meaning
        # is defined by train_or_eval, which is not shown here)
        train_or_eval(criterion, net, data_train, optim, True, config, epoch == 0)
        # evaluate
        train_or_eval(criterion, net, data_val, None, False, config, epoch == 0)

        if epoch in SAVE_EPOCHS:
            torch.save(
                net.state_dict(),
                str(Path(config['log_dir']) / ('model-%d.th' % epoch)))

        bzu.log.end_epoch()
def train(config):
    """Train ImagePolicyModelSS using a BirdViewPolicyModelSS teacher network.

    The teacher's model config comes from bzu.log.load_config on the teacher
    checkpoint path. The teacher's weights are loaded from that checkpoint and
    the teacher is switched to eval mode. Only the student's parameters are
    handed to the optimizer.

    Args:
        config: run configuration mapping. Keys read here: 'log_dir',
            'teacher_args'['model_path'], 'data_args' (forwarded to load_data),
            'camera_args' (forwarded to LocationLoss and CoordConverter),
            'model_args'['backbone'] and ['imagenet_pretrained'], 'device',
            'optimizer_args'['lr'] and 'max_epoch'.
    """
    bzu.log.init(config['log_dir'])
    bzu.log.save_config(config)
    teacher_path = config['teacher_args']['model_path']
    teacher_config = bzu.log.load_config(teacher_path)

    data_train, data_val = load_data(**config['data_args'])
    criterion = LocationLoss(**config['camera_args'])
    net = ImagePolicyModelSS(
        config['model_args']['backbone'],
        pretrained=config['model_args']['imagenet_pretrained']).to(
            config['device'])
    teacher_net = BirdViewPolicyModelSS(
        teacher_config['model_args']['backbone']).to(config['device'])
    # map_location lets a teacher checkpoint saved on another device (e.g. a
    # GPU) be deserialized directly onto the configured device.
    teacher_net.load_state_dict(
        torch.load(teacher_path, map_location=config['device']))
    teacher_net.eval()
    # NOTE(review): teacher parameters keep requires_grad=True. They are not in
    # the optimizer, so they are never updated, but autograd may still track
    # them unless train_or_eval wraps teacher calls in no_grad — confirm.

    coord_converter = CoordConverter(**config['camera_args'])

    optim = torch.optim.Adam(net.parameters(),
                             lr=config['optimizer_args']['lr'])

    for epoch in tqdm.tqdm(range(config['max_epoch'] + 1), desc='Epoch'):
        # Pass over the training data with the optimizer, then over the
        # validation data without one. The last argument is True only on the
        # first epoch; its meaning is defined by train_or_eval (not shown here).
        train_or_eval(coord_converter, criterion, net, teacher_net, data_train,
                      optim, True, config, epoch == 0)
        train_or_eval(coord_converter, criterion, net, teacher_net, data_val,
                      None, False, config, epoch == 0)

        if epoch in SAVE_EPOCHS:
            torch.save(net.state_dict(),
                       str(Path(config['log_dir']) / ('model-%d.th' % epoch)))

        bzu.log.end_epoch()
def load_birdview_model(backbone, ckpt, device='cuda'):
    """Build a BirdViewPolicyModelSS (all_branch=True) and load weights from a checkpoint.

    Args:
        backbone: backbone identifier forwarded to BirdViewPolicyModelSS.
        ckpt: checkpoint holding a state_dict (anything torch.load accepts).
        device: device the model is moved to and the weights are loaded onto.

    Returns:
        The model on ``device`` with the checkpoint's weights loaded. This
        function does not call ``eval()``; set the mode at the call site.
    """
    teacher_net = BirdViewPolicyModelSS(backbone, all_branch=True).to(device)
    # Deserialize straight onto the target device so that, e.g., a GPU-saved
    # checkpoint can be loaded with device='cpu' on a machine without CUDA.
    teacher_net.load_state_dict(torch.load(ckpt, map_location=device))

    return teacher_net