예제 #1
0
def load_policy_model(args, environment, device, folder=None):
    """Build a ``Policy`` network and restore its best checkpoint.

    Args:
        args: Options object; ``args.encoder`` and ``args.pretrained`` are
            forwarded to ``Policy`` as ``net`` and ``pretrained``.
        environment: Mapping that provides ``'action'`` (first positional
            argument of ``Policy``) and ``'input_size'`` (passed as
            ``input``).
        device: Target handed to ``model.to`` once the weights are restored.
        folder: Directory that contains ``best_model.ckpt``. When ``None``,
            ``./checkpoint/policy`` is used.

    Returns:
        The restored model, moved to ``device`` and switched to eval mode.
    """
    parent_folder = './checkpoint/policy'
    path = folder if folder is not None else parent_folder

    model = Policy(environment['action'],
                   net=args.encoder,
                   pretrained=args.pretrained,
                   input=environment['input_size'])
    # Deserialize onto the CPU so that a checkpoint written from a CUDA
    # device still loads on a machine where that device is unavailable.
    # load_state_dict copies the values into the model's own parameters,
    # and the model is moved to `device` right after.
    state_dict = torch.load(f'{path}/best_model.ckpt', map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
예제 #2
0
##### Policy #################################################
from models.policy import Policy

# Settings for the policy network. Only `lr` is read inside this section;
# gamma, lambd_entropy, r_neg and r_pos are presumably consumed by the
# training loop elsewhere (discount, entropy weight, reward magnitudes?) —
# TODO confirm against their call sites.
lr = 3e-4
gamma = 1
lambd_entropy = 0.3
# policy = Policy(hidden_dim=2, rnn_type='lstm')
policy = Policy(hidden_dim=4, input_dim=23, rnn_type=None)
r_neg = 5
r_pos = 5

if evaluation:
    # Restore trained weights for evaluation. map_location='cpu' lets a
    # checkpoint saved from a CUDA device deserialize regardless of which
    # devices are visible at load time; the whole model is moved to the GPU
    # by the policy.cuda() call at the end of this section.
    checkpoint = torch.load(
        "/home/chenwy/DynamicLightEnlighten/bdd100k_seg/policy_model/image_lstm.vgg.avgpool.argmax_delta.clip1.2.action-mean0.975_entropy.0_gamma.1_lr1e4_update.5_2019-02-01-13-05/model_best.pth.tar",
        map_location='cpu',
    )
    policy.load_state_dict(checkpoint['state_dict'])
    policy.eval()
else:
    # Training: build the parameter groups (with per-group learning rate).
    # NOTE: params_list is only defined in this branch, not in evaluation.
    # The commented-out groups below look like variants for other Policy
    # sub-modules (vgg / rnn / resnet / fcn) — kept for reference.
    # params_list = [{'params': policy.vgg.parameters(), 'lr': lr},]
    # params_list.append({'params': policy.rnn.parameters(), 'lr': lr*10})
    # params_list = [{'params': policy.resnet.parameters(), 'lr': lr},]
    ##################################
    # params_list = [{'params': policy.fcn.pretrained.parameters(), 'lr': lr*10},
    #                 {'params': policy.fcn.head.parameters(), 'lr': lr*10}]
    ##################################
    params_list = [{'params': policy.msn.parameters(), 'lr': lr * 10}]
    ##################################

# Move the model to the GPU in both modes.
policy = policy.cuda()
####################################################