Пример #1
0
def main():
    """Train a policy-value network from MCTS rollouts on an RTFM environment.

    All hyper-parameters are read from the module-level ``args`` namespace.
    Each episode:

    1. plays one rollout with ``train.play_rollout_pv_net_hop_mcts`` and stores
       it in the replay buffer;
    2. once ``i >= args.batch_size``, samples a batch, runs one optimisation
       step on ``pv_net``, steps the learning-rate scheduler and decays the
       sampling temperature;
    3. every 8 episodes (after updates have started) calls
       ``train.update_target_net`` with ``args.tau``;
    4. every ``args.checkpoint_period`` episodes saves the training state,
       the replay buffer and the network parameters under
       ``<args.save_dir>/<ID>/``.

    Raises:
        ValueError: if ``args.ucb_method`` is not 'old', 'AlphaGo' or 'Rosin'.
    """
    start = time.time()
    # define them by the parser values
    print("args.full_cross_entropy: ", args.full_cross_entropy)
    print("args.entropy_bonus: ", args.entropy_bonus)
    print("args.discrete_support_values: ", args.discrete_support_values)

    # Map the command-line shorthand to the identifier passed to the rollout.
    ucb_methods = {
        "old": "p-UCT-old",
        "AlphaGo": "p-UCT-AlphaGo",
        "Rosin": "p-UCT-Rosin",
    }
    if args.ucb_method not in ucb_methods:
        raise ValueError(
            "ucb_method should be one of 'old', 'AlphaGo', 'Rosin'.")
    ucb_method = ucb_methods[args.ucb_method]

    device = "cpu"  # disable GPU usage
    temperature = args.temperature

    # Snapshot of the run configuration, stored with every checkpoint.
    training_params = dict(
        ucb_C=args.ucb_C,
        discount=args.discount,
        episode_length=args.episode_length,
        max_actions=args.max_actions,
        num_simulations=args.num_simulations,
        device=device,
        n_episodes=args.n_episodes,
        memory_size=args.memory_size,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        tau=args.tau,
        dirichlet_alpha=args.dirichlet_alpha,
        exploration_fraction=args.exploration_fraction,
        temperature=args.temperature,
        full_cross_entropy=args.full_cross_entropy,
        entropy_bonus=args.entropy_bonus,
        entropy_weight=args.entropy_weight,
        discrete_support_values=args.discrete_support_values,
        ucb_method=ucb_method,
        num_trees=args.num_trees)

    network_params = {
        "emb_dim": args.emb_dim,
        "conv_channels": args.conv_channels,
        "conv_layers": args.conv_layers,
        "residual_layers": args.residual_layers,
        "linear_features_in": args.linear_features_in,
        "linear_feature_hidden": args.linear_feature_hidden
    }

    # Environment and simulator
    flags = utils.Flags(env="rtfm:%s-v0" % args.game_name)
    gym_env = utils.create_env(flags)
    featurizer = X.Render()
    game_simulator = mcts.FullTrueSimulator(gym_env, featurizer)
    object_ids = utils.get_object_ids_dict(game_simulator)

    # Networks: the online net and its target copy share the same class.
    if args.discrete_support_values:
        network_params["support_size"] = args.support_size
        net_cls = mcts.DiscreteSupportPVNet_v3
    else:
        net_cls = mcts.FixedDynamicsPVNet_v3
    pv_net = net_cls(gym_env, **network_params).to(device)
    target_net = net_cls(gym_env, **network_params).to(device)

    # Share memory of the 'actor' model, i.e. pv_net; it might not even be necessary at this point
    pv_net.share_memory()

    # Init target_net with same parameters of pv_net
    for trg_params, params in zip(target_net.parameters(),
                                  pv_net.parameters()):
        trg_params.data.copy_(params.data)

    # Training and optimization
    optimizer = torch.optim.Adam(pv_net.parameters(), lr=args.lr)
    # Per-update decay factors: (n_episodes - 1) applications shrink the
    # learning rate by 2 orders of magnitude and the temperature by 1.
    # With a single episode there is nothing to spread the decay over, so
    # fall back to no decay instead of dividing by zero.
    decay_steps = args.n_episodes - 1
    if decay_steps > 0:
        gamma = 10**(-2 / decay_steps)
        gamma_T = 10**(-1 / decay_steps)
    else:
        gamma = gamma_T = 1.0
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
    replay_buffer = train.HopPolicyValueReplayBuffer(args.memory_size,
                                                     args.discount)

    # Experiment ID
    if args.ID is None:
        ID = gen_PID()
    else:
        ID = args.ID
    print("Experiment ID: ", ID)

    total_rewards = []
    entropies = []
    losses = []
    policy_losses = []
    value_losses = []

    for i in range(args.n_episodes):
        ### Generate experience ###
        t0 = time.time()
        target_net.eval()  # just to make sure
        pv_net.eval()

        results = train.play_rollout_pv_net_hop_mcts(
            args.episode_length,
            object_ids,
            game_simulator,
            args.ucb_C,
            args.discount,
            args.max_actions,
            pv_net,
            args.num_simulations,
            args.num_trees,
            temperature,
            dirichlet_alpha=args.dirichlet_alpha,
            exploration_fraction=args.exploration_fraction,
            ucb_method=ucb_method)
        total_reward, frame_lst, reward_lst, done_lst, action_lst, probs_lst = results
        replay_buffer.store_episode(frame_lst, reward_lst, done_lst,
                                    action_lst, probs_lst)
        total_rewards.append(total_reward)
        rollout_time = (time.time() - t0) / 60
        if (i + 1) % 10 == 0:
            print("\nEpisode %d - Total reward %d " % (i + 1, total_reward))
            print("Rollout time: %.2f" % (rollout_time))

        # Start updating only after more than batch_size episodes are stored.
        if i >= args.batch_size:
            ### Update ###
            target_net.eval()  # just to make sure
            frames, target_values, actions, probs = replay_buffer.get_batch(
                args.batch_size, args.n_steps, target_net, device)
            pv_net.train()
            update_results = train.compute_PV_net_update_v1(
                pv_net, frames, target_values, actions, probs, optimizer,
                args.full_cross_entropy, args.entropy_bonus,
                args.entropy_weight, args.discrete_support_values)
            loss, entropy, policy_loss, value_loss = update_results
            scheduler.step()
            temperature = gamma_T * temperature

            # update target network only from time to time
            if (i + 1) % 8 == 0:
                train.update_target_net(target_net, pv_net, args.tau)

            if (i + 1) % 10 == 0:
                print("Loss: %.4f - Policy loss: %.4f - Value loss: %.4f" %
                      (loss, policy_loss, value_loss))
                print("Entropy: %.4f" % entropy)
            losses.append(loss)
            entropies.append(entropy)
            policy_losses.append(policy_loss)
            value_losses.append(value_loss)

        if (i + 1) % 50 == 0:
            # Print update
            print("\nAverage reward over last 50 rollouts: %.2f\n" %
                  (np.mean(total_rewards[-50:])))

        if (i + 1) % args.checkpoint_period == 0:
            # Save checkpoint (no plots in the script)
            target_net.eval()
            pv_net.eval()

            d = dict(
                episodes_played=i,
                training_params=training_params,
                object_ids=object_ids,
                pv_net=pv_net,
                target=target_net,
                losses=losses,
                policy_losses=policy_losses,
                value_losses=value_losses,
                total_rewards=total_rewards,
                entropies=entropies,
                optimizer=optimizer,
            )

            experiment_path = "%s/%s/" % (args.save_dir, ID)
            # makedirs also creates a missing save_dir, where mkdir would
            # fail, and exist_ok avoids the check-then-create race.
            os.makedirs(experiment_path, exist_ok=True)
            torch.save(d, experiment_path + 'training_dict_%d' % (i + 1))
            torch.save(replay_buffer, experiment_path + 'replay_buffer')
            torch.save(network_params, experiment_path + 'network_params')
            print("Saved checkpoint.")

    end = time.time()
    elapsed = (end - start) / 60
    print("Run took %.1f min." % elapsed)
Пример #2
0
def main():
    """Train a value network from MCTS rollouts on a fixed RTFM environment.

    All hyper-parameters are read from the module-level ``args`` namespace.
    Each episode:

    1. plays one rollout with ``train.play_rollout_value_net`` and stores it
       in the n-step replay buffer;
    2. every ``args.net_update_period`` episodes samples a batch, runs one
       optimisation step on ``value_net`` and steps the lr scheduler;
    3. every ``args.target_update_period`` episodes calls
       ``train.update_target_net`` with ``args.tau``;
    4. every ``args.checkpoint_period`` episodes saves the training state
       under ``./<args.save_dir>/<ID>/``.

    A failing update step is reported and skipped; it does not abort the run.
    """
    start = time.time()
    device = args.device

    # Snapshot of the run configuration, stored with every checkpoint.
    training_params = dict(ucb_C=args.ucb_C,
                           discount=args.discount,
                           episode_length=args.episode_length,
                           max_actions=args.max_actions,
                           num_simulations=args.num_simulations,
                           device=args.device,
                           n_episodes=args.n_episodes,
                           memory_size=args.memory_size,
                           batch_size=args.batch_size,
                           n_steps=args.n_steps,
                           tau=args.tau)

    # Environment and simulator
    flags = utils.Flags(env="rtfm:groups_simple_stationary-v0")
    gym_env = utils.create_env(flags)
    featurizer = X.Render()
    game_simulator = mcts.FullTrueSimulator(gym_env, featurizer)
    object_ids = utils.get_object_ids_dict(game_simulator)

    # Networks
    value_net = mcts.FixedDynamicsValueNet_v2(gym_env).to(device)
    target_net = mcts.FixedDynamicsValueNet_v2(gym_env).to(device)
    # Init target_net with same parameters of value_net
    for trg_params, params in zip(target_net.parameters(),
                                  value_net.parameters()):
        trg_params.data.copy_(params.data)

    # Training and optimization
    optimizer = torch.optim.Adam(value_net.parameters(), lr=args.lr)
    # The scheduler steps once per network update, so spread a 2-orders-of-
    # magnitude lr decay over (number of updates - 1) steps. When there is at
    # most one update the exponent is undefined (division by zero) or would
    # grow the lr, so disable the decay in that case.
    decay_steps = args.n_episodes / args.net_update_period - 1
    gamma = 10**(-2 / decay_steps) if decay_steps > 0 else 1.0
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
    loss_fn = F.mse_loss
    rb = train.nStepsReplayBuffer(args.memory_size, args.discount)

    # Experiment ID
    if args.ID is None:
        ID = gen_PID()
    else:
        ID = args.ID
    print("Experiment ID: ", ID)

    total_rewards = []
    losses = []
    for i in range(args.n_episodes):
        ### Generate experience ###
        t0 = time.time()
        value_net.eval()
        total_reward, frame_lst, reward_lst, done_lst = train.play_rollout_value_net(
            value_net,
            game_simulator,
            args.episode_length,
            args.ucb_C,
            args.discount,
            args.max_actions,
            args.num_simulations,
            mode="predict",
            bootstrap="no")
        t1 = time.time()
        total_rewards.append(total_reward)
        print("\nEpisode %d - Total reward %d" % (i + 1, total_reward))
        rollout_time = (t1 - t0) / 60
        print("Rollout time: %.2f" % (rollout_time))
        rb.store_episode(frame_lst, reward_lst, done_lst)

        ### Train value_net ###

        try:
            # update value network all the time
            if (i + 1) % args.net_update_period == 0:
                target_net.eval()
                frames, targets = rb.get_batch(args.batch_size, args.n_steps,
                                               args.discount, target_net,
                                               device)
                value_net.train()
                loss = train.compute_update_v1(value_net, frames, targets,
                                               loss_fn, optimizer)
                scheduler.step()
                print("Loss: %.4f" % loss)
                losses.append(loss)
            # update target network only from time to time
            if (i + 1) % args.target_update_period == 0:
                train.update_target_net(target_net, value_net, args.tau)

        except Exception as err:
            # Best-effort update: keep training, but report the failure.
            # Catching Exception rather than using a bare except lets
            # KeyboardInterrupt / SystemExit stop the run as expected.
            print("Skipping network update at episode %d: %r" % (i + 1, err))

        if (i + 1) % 50 == 0:
            # Print update
            print("\nAverage reward over last 50 rollouts: %.2f\n" %
                  (np.mean(total_rewards[-50:])))

        if (i + 1) % args.checkpoint_period == 0:
            # Save checkpoint (no plots in the script)
            target_net.eval()
            value_net.eval()

            d = dict(episodes_played=i,
                     training_params=training_params,
                     object_ids=object_ids,
                     value_net=value_net,
                     target_net=target_net,
                     rb=rb,
                     losses=losses,
                     total_rewards=total_rewards)

            experiment_path = "./%s/%s/" % (args.save_dir, ID)
            # makedirs also creates a missing save_dir, where mkdir would
            # fail, and exist_ok avoids the check-then-create race.
            os.makedirs(experiment_path, exist_ok=True)
            torch.save(d, experiment_path + 'training_dict_%d' % (i + 1))
            print("Saved checkpoint.")

    end = time.time()
    elapsed = (end - start) / 60
    print("Run took %.1f min." % elapsed)