Example #1
    # prepare the network: an MLP mapping observations to tanh-squashed actions
    M = args.layer_size
    network = Mlp(
        input_size=obs_dim,
        output_size=action_dim,
        hidden_sizes=[M, M],
        output_activation=F.tanh,
    ).to(device)

    optimizer = optim.Adam(network.parameters(), lr=args.lr)
    epch = 0

    if args.load_model is not None:
        if os.path.isfile(args.load_model):
            checkpoint = torch.load(args.load_model)
            network.load_state_dict(checkpoint['network_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            t_loss = checkpoint['train_loss']
            epch = checkpoint['epoch']
            print('Loading model: {}. Resuming from epoch: {}'.format(
                args.load_model, epch))
        else:
            print('Model: {} not found'.format(args.load_model))

    best_loss = np.inf
    for epoch in range(epch, args.epochs):
        t_loss = train(network, dataloader, optimizer, epoch, device)
        print('=> epoch: {} Average Train loss: {:.4f}'.format(epoch, t_loss))

        if use_tb:
            logger.add_scalar(log_dir + '/train-loss', t_loss, epoch)
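
The resume branch above expects a checkpoint containing the keys 'network_state_dict', 'optimizer_state_dict', 'train_loss', and 'epoch'; the matching save side is not shown in this example. Below is a minimal sketch of such a save routine, assuming those same keys (the save_checkpoint helper and the best-loss bookkeeping are illustrative, not part of the original source).

import torch

def save_checkpoint(path, network, optimizer, epoch, train_loss):
    # write a checkpoint whose keys mirror what the resume branch reads back
    torch.save({
        'network_state_dict': network.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'train_loss': train_loss,
    }, path)

# e.g. keep the best model seen so far inside the training loop:
# if t_loss < best_loss:
#     best_loss = t_loss
#     save_checkpoint('best_model.pt', network, optimizer, epoch, t_loss)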
Example #2
def experiment(variant):
    eval_env = gym.make(variant['env_name'])
    expl_env = eval_env
    obs_dim = expl_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    M = variant['layer_size']
    # Q and policy networks
    qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    target_qf1 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)
    target_qf2 = FlattenMlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    ).to(ptu.device)

    # initialize with bc or not
    if variant['bc_model'] is None:
        policy = TanhGaussianPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_sizes=[M, M],
        ).to(ptu.device)
    else:
        bc_model = Mlp(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=[64, 64],
            output_activation=F.tanh,
        ).to(ptu.device)

        checkpoint = torch.load(variant['bc_model'], map_location=map_location)
        bc_model.load_state_dict(checkpoint['network_state_dict'])
        print('Loading bc model: {}'.format(variant['bc_model']))

        # policy initialized with bc
        policy = TanhGaussianPolicy_BC(
            obs_dim=obs_dim,
            action_dim=action_dim,
            mean_network=bc_model,
            hidden_sizes=[M, M],
        ).to(ptu.device)

    # if bonus: define bonus networks
    if not variant['offline']:
        bonus_layer_size = variant['bonus_layer_size']
        bonus_network = Mlp(
            input_size=obs_dim + action_dim,
            output_size=1,
            hidden_sizes=[bonus_layer_size, bonus_layer_size],
            output_activation=F.sigmoid,
        ).to(ptu.device)

        checkpoint = torch.load(variant['bonus_path'],
                                map_location=map_location)
        bonus_network.load_state_dict(checkpoint['network_state_dict'])
        print('Loading bonus model: {}'.format(variant['bonus_path']))

        # the bonus network and the Q functions share the same
        # (obs + action) -> 1 MLP shape when bonus_layer_size == M,
        # so the bonus checkpoint can seed the target Q functions
        if variant['initialize_Q'] and bonus_layer_size == M:
            target_qf1.load_state_dict(checkpoint['network_state_dict'])
            target_qf2.load_state_dict(checkpoint['network_state_dict'])
            print('Initialized target QF1 and QF2 with the bonus model: {}'.format(
                variant['bonus_path']))
        elif variant['initialize_Q']:
            print('Size mismatch between the Q and bonus networks - '
                  'turning off the initialization')

    # eval_policy = MakeDeterministic(policy)
    eval_path_collector = CustomMDPPathCollector(eval_env)
    expl_path_collector = MdpPathCollector(
        expl_env,
        policy,
    )
    buffer_filename = variant['buffer_filename']

    replay_buffer = EnvReplayBuffer(
        variant['replay_buffer_size'],
        expl_env,
    )

    dataset = eval_env.unwrapped.get_dataset()

    # load_hdf5 (defined elsewhere in the repository) presumably copies the
    # offline dataset transitions into the replay buffer, capped at max_size
    load_hdf5(dataset, replay_buffer, max_size=variant['replay_buffer_size'])

    if variant['normalize']:
        obs_mu = dataset['observations'].mean(axis=0)
        obs_std = dataset['observations'].std(axis=0)
        bonus_norm_param = [obs_mu, obs_std]
    else:
        bonus_norm_param = [None] * 2

    # reward shift: rewards_shift_param is handed to the trainer, which
    # presumably subtracts it so that the minimum dataset reward becomes
    # variant['reward_shift']
    if variant['reward_shift'] is not None:
        rewards_shift_param = min(dataset['rewards']) - variant['reward_shift']
        print('.... reward is shifted: {}'.format(rewards_shift_param))
    else:
        rewards_shift_param = None

    if variant['offline']:
        trainer = SACTrainer(env=eval_env,
                             policy=policy,
                             qf1=qf1,
                             qf2=qf2,
                             target_qf1=target_qf1,
                             target_qf2=target_qf2,
                             rewards_shift_param=rewards_shift_param,
                             **variant['trainer_kwargs'])
        print('Agent of type offline SAC created')

    elif variant['bonus'] == 'bonus_add':
        trainer = SAC_BonusTrainer(
            env=eval_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            bonus_network=bonus_network,
            beta=variant['bonus_beta'],
            use_bonus_critic=variant['use_bonus_critic'],
            use_bonus_policy=variant['use_bonus_policy'],
            use_log=variant['use_log'],
            bonus_norm_param=bonus_norm_param,
            rewards_shift_param=rewards_shift_param,
            device=ptu.device,
            **variant['trainer_kwargs'])
        print('Agent of type SAC + additive bonus created')
    elif variant['bonus'] == 'bonus_mlt':
        trainer = SAC_BonusTrainer_Mlt(
            env=eval_env,
            policy=policy,
            qf1=qf1,
            qf2=qf2,
            target_qf1=target_qf1,
            target_qf2=target_qf2,
            bonus_network=bonus_network,
            beta=variant['bonus_beta'],
            use_bonus_critic=variant['use_bonus_critic'],
            use_bonus_policy=variant['use_bonus_policy'],
            bonus_norm_param=bonus_norm_param,
            rewards_shift_param=rewards_shift_param,
            device=ptu.device,
            **variant['trainer_kwargs'])
        print('Agent of type SAC + multiplicative bonus created')

    else:
        raise ValueError('Unknown bonus option: {}'.format(variant['bonus']))

    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        batch_rl=True,
        q_learning_alg=True,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
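
experiment() pulls its entire configuration from a flat variant dictionary. A minimal driver for the plain offline-SAC branch might look like the sketch below; all values are illustrative assumptions, the environment name presumes a D4RL task (get_dataset() is called on the unwrapped env), and trainer_kwargs / algorithm_kwargs are filled with typical rlkit-style fields rather than the repository's actual defaults.

variant = dict(
    env_name='halfcheetah-medium-v0',  # assumed D4RL task; get_dataset() needs one
    layer_size=256,
    bc_model=None,                     # or a path to a behavior-cloning checkpoint
    offline=True,                      # take the plain offline-SAC branch
    bonus=None,                        # bonus-related keys are only read when offline is False
    buffer_filename=None,
    replay_buffer_size=int(2e6),
    normalize=False,
    reward_shift=None,
    trainer_kwargs=dict(
        discount=0.99,                 # illustrative SAC hyperparameters
        soft_target_tau=5e-3,
        policy_lr=3e-4,
        qf_lr=3e-4,
    ),
    algorithm_kwargs=dict(
        num_epochs=1000,               # illustrative loop sizes
        batch_size=256,
        max_path_length=1000,
        num_eval_steps_per_epoch=5000,
        num_expl_steps_per_train_loop=1000,
        num_trains_per_train_loop=1000,
    ),
)
experiment(variant)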