예제 #1
0
def train(model, epochs, train_gen, val_gen, lambd, rec_path, model_root):
    """Compile ``model`` with Adam and MSE loss, then fit it on ``train_gen``.

    Args:
        model: Model object exposing ``compile`` and ``fit_generator``.
        epochs: Number of epochs to train for; also used to place the
            schedule breakpoints.
        train_gen: Generator passed to ``fit_generator`` as training data.
        val_gen: Generator used both as ``validation_data`` and by the
            ``ValidateRecordandSaveBest`` callback.
        lambd: Constant value scheduled under the ``"kl_gauss"`` key.
        rec_path: Forwarded to ``ValidateRecordandSaveBest``.
        model_root: Directory that receives the per-epoch ``.h5``
            checkpoints; also forwarded to ``ValidateRecordandSaveBest``.
    """
    # Learning-rate breakpoints: 1e-3 at epoch 0, 5e-4 at the midpoint,
    # 1e-4 at the final epoch and for anything outside the listed range.
    lr_points = [[0, 1e-3], [epochs / 2, 5e-4], [epochs, 1e-4]]
    learning_rate = PiecewiseSchedule(lr_points, outside_value=1e-4)

    # In this implementation lambd is kept fixed: both breakpoints and the
    # outside value are the same number.
    lambd_points = [[0, lambd], [epochs, lambd]]
    fixed_lambd = PiecewiseSchedule(lambd_points, outside_value=lambd)

    annealer = AnnealEveryEpoch({
        "lr": learning_rate,
        "kl_gauss": fixed_lambd,
    })
    best_saver = ValidateRecordandSaveBest(val_gen, rec_path, model_root)
    # One checkpoint file per epoch, named after the epoch number.
    checkpoint_pattern = os.path.join(model_root, "model_{epoch:02d}.h5")
    checkpointer = ModelCheckpoint(filepath=checkpoint_pattern, period=1)

    model.compile(optimizer=optimizers.Adam(), loss="mse")
    model.fit_generator(
        train_gen,
        workers=4,
        epochs=epochs,
        callbacks=[annealer, best_saver, checkpointer],
        validation_data=val_gen,
    )
예제 #2
0
def train(model, epochs, train_gen, val_gen, lambd, rec_path, model_path):
    """Compile ``model`` with Adam and MSE loss, then fit it on ``train_gen``.

    Args:
        model: Model object exposing ``compile`` and ``fit_generator``.
        epochs: Number of epochs to train for; also used to place the
            schedule breakpoints.
        train_gen: Generator passed to ``fit_generator`` as training data.
        val_gen: Generator handed to the ``ValidateRecordandSaveBest``
            callback (it is not passed to ``fit_generator`` directly).
        lambd: Constant value scheduled under the ``"kl_gauss"`` key.
        rec_path: Forwarded to ``ValidateRecordandSaveBest``.
        model_path: Forwarded to ``ValidateRecordandSaveBest``.
    """
    # Learning-rate breakpoints: 5e-4 at epoch 0, 1e-5 at the midpoint,
    # 5e-5 at the final epoch; 5e-4 outside the listed range.
    # NOTE(review): the midpoint value is lower than the final one, and the
    # outside value equals the starting value - confirm this is intended.
    lr_points = [[0, 5e-4], [epochs / 2, 1e-5], [epochs, 5e-5]]
    learning_rate = PiecewiseSchedule(lr_points, outside_value=5e-4)

    # In this implementation lambd is kept fixed: both breakpoints and the
    # outside value are the same number.
    lambd_points = [[0, lambd], [epochs, lambd]]
    fixed_lambd = PiecewiseSchedule(lambd_points, outside_value=lambd)

    annealer = AnnealEveryEpoch({
        "lr": learning_rate,
        "kl_gauss": fixed_lambd,
    })
    best_saver = ValidateRecordandSaveBest(val_gen, rec_path, model_path)

    model.compile(optimizer=optimizers.Adam(), loss="mse")
    model.fit_generator(
        train_gen,
        workers=4,
        epochs=epochs,
        callbacks=[annealer, best_saver],
    )
예제 #3
0
def main(dataset_type):
    """Run the saved ConvLSTMMulti model over a dataset and draw its outputs.

    Reads the module-level ``args`` object for model and sampler settings and
    moves tensors to the GPU with ``.cuda()``, so a CUDA device is required.

    Args:
        dataset_type: ``"bdd"`` or ``"kitti"``; selects the dataset class and
            its hard-coded root directory.

    Raises:
        ValueError: If ``dataset_type`` is not a supported dataset name.
    """
    if dataset_type == "bdd":
        dataset = BDDDataset(root="/home/jinkun/datasets/bdd_val")
    elif dataset_type == "kitti":
        dataset = KittiDataset(
            root="/home/jinkun/datasets/kitti/2011_09_26_drive_0019_extract")
    else:
        # An explicit exception instead of ``assert (0)``: asserts are
        # stripped under ``python -O``, which would leave ``dataset`` unbound.
        raise ValueError(
            "unsupported dataset_type {!r}; expected 'bdd' or 'kitti'".format(
                dataset_type))

    dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

    model = ConvLSTMMulti(args).cuda()
    model.eval()
    state_dict = torch.load("/home/jinkun/git/spc_trained.pt")
    model.load_state_dict(state_dict)

    guides = generate_guide_grid(args.bin_divide)
    actionsampler = ActionSampleManager(args, guides)

    # Schedule from 1.0 at step 0 to 0.02 at ``args.epsilon_frames``, and 0.02
    # outside that range. It is only handed to the sampler below, which is
    # always called with ``step=0`` and ``testing=True``.
    exploration = PiecewiseSchedule([
        (0, 1.0),
        (args.epsilon_frames, 0.02),
    ],
                                    outside_value=0.02)

    for index, data in enumerate(dataloader):
        imgs, imgs_ori = data
        # imgs: sized [batch_size, frame, h, w, 3]
        # The two transposes move the channel axis in front of h and w.
        imgs = torch.transpose(imgs, 2, 4).transpose(3, 4).cuda()
        imgs_ori = torch.transpose(imgs_ori, 2, 4).transpose(3, 4)
        # Merge the frame and channel axes into one.
        # NOTE(review): the hard-coded (1, 9, 256, 256) presumably means
        # 3 frames x 3 channels at 256x256 - confirm against the datasets.
        obs_var = imgs.reshape(1, 9, 256, 256).cuda()
        # Last frame of the single batch item, back in (h, w, channel) order.
        obs = imgs[0, -1].transpose(0, 2).transpose(0, 1).cpu().numpy()
        obs_ori = imgs_ori[0, -1].transpose(0, 2).transpose(0, 1).numpy()
        # Fixed action [-1.0, 0.0] repeated once per history slot, giving
        # shape (1, frame_history_len - 1, 2).
        action_var = torch.from_numpy(np.array([-1.0, 0.0])).repeat(
            1, args.frame_history_len - 1, 1).float().cuda()
        action, guidance_action, p = actionsampler.sample_action(
            net=model,
            obs=obs,
            obs_var=obs_var,
            action_var=action_var,
            exploration=exploration,
            step=0,
            testing=True)
        obs_var = norm_image(obs_var).unsqueeze(0)
        action = torch.Tensor(action).cuda().unsqueeze(0)
        output = model(obs_var, action, training=False, action_var=action_var)

        # Strip the leading axes that were added for the forward pass.
        action = action[0][0]
        guidance_action = guidance_action[0]
        draw_output(args, obs_ori, output, action, guidance_action, p, index)
예제 #4
0
파일: model_test.py 프로젝트: vipjml/nesgym
	def setUp(self):
		"""Create the ``DoubleDQN`` instance under test as ``self.dqn``."""
		total_steps = 40000000
		# (timestep, value) breakpoints handed to PiecewiseSchedule: 1.0 at
		# step 0, 0.1 at 1e5, 0.01 at half of total_steps; 0.01 is also the
		# value used outside the listed range.
		schedule_points = [
			(0, 1.0),
			(1e5, 0.1),
			(total_steps / 2, 0.01),
		]
		epsilon_schedule = PiecewiseSchedule(schedule_points, outside_value=0.01)
		self.dqn = DoubleDQN(
			image_shape=(84, 84, 1),
			num_actions=16,
			training_starts=10000,
			target_update_freq=4000,
			training_batch_size=64,
			exploration=epsilon_schedule,
		)
예제 #5
0
def init_models(args):
    """Build the trainable and frozen networks, optimizer and schedule.

    Two ``ConvLSTMMulti`` instances are created: ``train_net`` (gradients
    enabled, train mode) and ``net`` (gradients disabled, eval mode). After
    ``load_model`` runs, ``net`` receives a copy of ``train_net``'s weights.

    Args:
        args: Configuration object. This function reads ``save_path``,
            ``resume``, ``data_parallel``, ``lr`` and ``epsilon_frames``, and
            passes the whole object to ``ConvLSTMMulti``.

    Returns:
        Tuple ``(train_net, net, optimizer, epoch, exploration,
        num_imgs_start)``. ``num_imgs_start`` is 0 unless resuming from a
        readable training log.
    """
    train_net = ConvLSTMMulti(args)
    for param in train_net.parameters():
        param.requires_grad = True
    train_net.train()

    net = ConvLSTMMulti(args)
    for param in net.parameters():
        param.requires_grad = False
    net.eval()

    train_net, epoch = load_model(args.save_path,
                                  train_net,
                                  resume=args.resume)
    net.load_state_dict(train_net.state_dict())

    if torch.cuda.is_available():
        train_net = train_net.cuda()
        net = net.cuda()
        if args.data_parallel:
            train_net = torch.nn.DataParallel(train_net)
            net = torch.nn.DataParallel(net)
    optimizer = optim.Adam(train_net.parameters(), lr=args.lr, amsgrad=True)

    # Schedule from 1.0 at step 0 to 0.02 at ``args.epsilon_frames``, and 0.02
    # outside that range.
    exploration = PiecewiseSchedule([
        (0, 1.0),
        (args.epsilon_frames, 0.02),
    ],
                                    outside_value=0.02)

    num_imgs_start = 0
    if args.resume:
        log_path = os.path.join(args.save_path, 'log_train_torcs.txt')
        try:
            # ``with`` closes the log file promptly instead of leaking it.
            with open(log_path) as log_file:
                last_line = log_file.readlines()[-1]
            # Second space-separated field of the last line, backed off by
            # 1000 and clamped at 0.
            num_imgs_start = max(int(last_line.split(' ')[1]) - 1000, 0)
        except (OSError, IndexError, ValueError):
            # Missing/unreadable log, empty log, or malformed last line:
            # start counting from scratch. Narrower than a bare ``except`` so
            # KeyboardInterrupt and programming errors still propagate.
            num_imgs_start = 0

    return train_net, net, optimizer, epoch, exploration, num_imgs_start
예제 #6
0
def init_models(args):
    """Build the network, optimizer and exploration schedule for a run.

    Args:
        args: Configuration object. This function reads ``eval``,
            ``save_path``, ``resume``, ``data_parallel``, ``optim``, ``lr``,
            ``epsilon_frames`` and ``env``, and passes the whole object to
            ``ConvLSTMMulti`` and ``load_model``.

    Returns:
        Tuple ``(train_net, optimizer, epoch, exploration, num_imgs_start)``.
        ``num_imgs_start`` is 0 unless resuming and a value could be read
        from the environment's training log.

    Raises:
        ValueError: If ``args.optim`` is neither ``'Adam'`` nor ``'SGD'``.
        OSError, IndexError, ValueError: When resuming a ``'torcs'`` run whose
            log file is missing, empty or malformed (this branch has no
            fallback).
    """
    train_net = ConvLSTMMulti(args)
    for param in train_net.parameters():
        param.requires_grad = True

    if not args.eval:
        train_net.train()

    # The third value returned by load_model is not used here.
    train_net, epoch, _step = load_model(args,
                                         args.save_path,
                                         train_net,
                                         resume=args.resume)

    # NOTE: loading of pretrained weights is disabled; kept for reference.
    # if not args.resume and not args.eval:
    #     pretrain_net = torch.load(args.pretrain_model)
    #     try:
    #         train_net.load_state_dict(pretrain_net)
    #     except:
    #         train_net.load_state_dict(pretrain_net, strict=False)
    #         print("strict load checkpoint {} failed. turn to non-strict loading mode".format(pretrain_net))
    #     train_net.conv_lstm.freeze_bn()

    if torch.cuda.is_available():
        train_net = train_net.cuda()
        if args.data_parallel:
            train_net = torch.nn.DataParallel(train_net)

    if args.optim == 'Adam':
        optimizer = optim.Adam(train_net.parameters(),
                               lr=args.lr,
                               amsgrad=True)
    elif args.optim == 'SGD':
        optimizer = optim.SGD(train_net.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=1e-4)
    else:
        # An explicit exception instead of ``assert (0)``: asserts are
        # stripped under ``python -O``, which would leave ``optimizer``
        # unbound.
        raise ValueError(
            "unsupported optimizer {!r}; expected 'Adam' or 'SGD'".format(
                args.optim))

    # Schedule from 1.0 at step 0 to 0.02 at ``args.epsilon_frames``, and 0.02
    # outside that range.
    exploration = PiecewiseSchedule([
        (0, 1.0),
        (args.epsilon_frames, 0.02),
    ],
                                    outside_value=0.02)

    # Default covers both "not resuming" and "resuming with an env that has
    # no known log format"; the latter previously left the name unbound and
    # raised UnboundLocalError at the return statement.
    num_imgs_start = 0
    if args.resume:
        if args.env == 'torcs':
            log_path = os.path.join(args.save_path, 'log_train_torcs.txt')
            with open(log_path) as log_file:
                last_line = log_file.readlines()[-1]
            # Second space-separated field of the last line, backed off by
            # 1000 and clamped at 0.
            num_imgs_start = max(int(last_line.split(' ')[1]) - 1000, 0)
        elif 'carla' in args.env:
            reward_log = os.path.join(args.save_path,
                                      'reward_train_{}.txt'.format(args.env))
            try:
                with open(reward_log) as log_file:
                    last_line = log_file.readlines()[-1]
                # Fourth space-separated field of the last line.
                num_imgs_start = int(last_line.split(' ')[3])
            except (OSError, IndexError, ValueError):
                # Missing/unreadable log, empty log, or malformed last line:
                # start counting from scratch.
                num_imgs_start = 0

    return train_net, optimizer, epoch, exploration, num_imgs_start