Example 1
def get_modules(model_params, action_spec):
    """Get agent modules."""
    model_params, n_q_fns = model_params
    if len(model_params) == 1:
        model_params = tuple([model_params[0]] * 3)
    elif len(model_params) < 3:
        raise ValueError('Bad model parameters %s.' % model_params)

    def q_net_factory():
        return networks.CriticNetwork(fc_layer_params=model_params[0])

    def p_net_factory():
        return networks.ActorNetwork(action_spec,
                                     fc_layer_params=model_params[1])

    def vae_net_factory():
        return networks.BCQVAENetwork(action_spec,
                                      fc_layer_params=model_params[2])

    modules = utils.Flags(
        q_net_factory=q_net_factory,
        p_net_factory=p_net_factory,
        vae_net_factory=vae_net_factory,
        n_q_fns=n_q_fns,
    )
    return modules
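
In this example (and in Examples 3 and 4 below), utils.Flags is used as a plain keyword-argument container whose entries are read back as attributes (e.g. modules.n_q_fns). The class itself is not shown here; the following is a minimal sketch of such an attribute bag, assuming that behaviour rather than the project's actual implementation.

class Flags(object):
    """Minimal attribute-bag sketch of utils.Flags (assumed behaviour)."""

    def __init__(self, **kwargs):
        # Store every keyword argument as an attribute of the instance.
        for key, value in kwargs.items():
            setattr(self, key, value)

# Hypothetical usage mirroring Example 1:
# modules = Flags(q_net_factory=q_net_factory, n_q_fns=2)
# modules.n_q_fns  # -> 2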
Example 2
    def recv(self, packet):
        '''
        Handles all received packets of this connection as per the RFC specification
        Args:
            packet (TCPPacket): A TCP packet of this connection
        '''
        data = packet.data
        self.acknm += len(data)

        # Print the sent data in ASCII
        if len(data) != 0:
            print("\33[1mThe internet said:\33[33m ",
                  "".join([chr(d) for d in data]) + '\33[0m')

        if self.state == State.CLOSED:
            print("\33[31m\33[1mError:\33[0m\33[1m connection doesn't exist.")
            return

        if self.state == State.LISTEN:
            if packet.flags == utils.Flags.flag('syn'):
                # Send SYN,ACK
                self.state = State.SYN_RCVD
                self.acknm = int.from_bytes(packet.seq_num, 'big') + 1
                # SYN, ACK (TODO: make simpler to do)
                flags = utils.Flags(0x12)
                snd = self.mkpkt(flags=flags)
                self.send(snd)
                self.sqnm += 1
                # self.send(snd)
            return

        if self.state == State.SYN_RCVD:
            if packet.flags == utils.Flags.flag('ack'):
                self.state = State.ESTAB
                print('\n\33[1m\33[32mConnection Established!!\33[0m')
            return

        if self.state == State.ESTAB:
            flags = utils.Flags('ack')  # ACK
            print('flag=', flags)

            snd = self.mkpkt(flags=flags)
            if packet.flags & utils.Flags.flag('fin'):
                print(f'\n\33[1m\33[31mClosing connection!!\33[0m')
                self.state = State.CLOSE_WAIT
                self.close()
            self.send(snd)
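
The SYN,ACK reply above is built as utils.Flags(0x12), with a TODO asking for something simpler. The 0x12 literal is the bitwise OR of the standard TCP flag bits SYN (0x02) and ACK (0x10). The sketch below spells out the bit values and a hypothetical helper that would address the TODO, assuming utils.Flags accepts an integer bitmask as in the example.

# TCP flag bits as defined in the TCP header:
FIN, SYN, RST, PSH, ACK, URG = 0x01, 0x02, 0x04, 0x08, 0x10, 0x20

assert SYN | ACK == 0x12  # the SYN,ACK combination sent from LISTEN

# Hypothetical helper for the "make simpler to do" TODO, assuming
# utils.Flags accepts an integer bitmask as in the example:
# def make_flags(*names):
#     bits = {'fin': FIN, 'syn': SYN, 'rst': RST,
#             'psh': PSH, 'ack': ACK, 'urg': URG}
#     return utils.Flags(sum(bits[n] for n in names))
#
# flags = make_flags('syn', 'ack')  # equivalent to utils.Flags(0x12)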
Example 3
def main():
    start = time.time()
    # define them by the parser values
    print("args.full_cross_entropy: ", args.full_cross_entropy)
    print("args.entropy_bonus: ", args.entropy_bonus)
    print("args.discrete_support_values: ", args.discrete_support_values)
    if args.ucb_method == "old":
        ucb_method = "p-UCT-old"
    elif args.ucb_method == "AlphaGo":
        ucb_method = "p-UCT-AlphaGo"
    elif args.ucb_method == "Rosin":
        ucb_method = "p-UCT-Rosin"
    else:
        raise Exception(
            "ucb_method should be one of 'old', 'AlphaGo', 'Rosin'.")

    training_params = dict(
        ucb_C=args.ucb_C,
        discount=args.discount,
        episode_length=args.episode_length,
        max_actions=args.max_actions,
        num_simulations=args.num_simulations,
        device="cpu",  # disable GPU usage 
        n_episodes=args.n_episodes,
        memory_size=args.memory_size,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        tau=args.tau,
        dirichlet_alpha=args.dirichlet_alpha,
        exploration_fraction=args.exploration_fraction,
        temperature=args.temperature,
        full_cross_entropy=args.full_cross_entropy,
        entropy_bonus=args.entropy_bonus,
        entropy_weight=args.entropy_weight,
        discrete_support_values=args.discrete_support_values,
        ucb_method=ucb_method,
        num_trees=args.num_trees)

    device = "cpu"  # disable GPU usage
    temperature = args.temperature

    network_params = {
        "emb_dim": args.emb_dim,
        "conv_channels": args.conv_channels,
        "conv_layers": args.conv_layers,
        "residual_layers": args.residual_layers,
        "linear_features_in": args.linear_features_in,
        "linear_feature_hidden": args.linear_feature_hidden
    }

    # Environment and simulator
    flags = utils.Flags(env="rtfm:%s-v0" % args.game_name)
    gym_env = utils.create_env(flags)
    featurizer = X.Render()
    game_simulator = mcts.FullTrueSimulator(gym_env, featurizer)
    object_ids = utils.get_object_ids_dict(game_simulator)

    # Networks
    if args.discrete_support_values:
        network_params["support_size"] = args.support_size
        pv_net = mcts.DiscreteSupportPVNet_v3(gym_env,
                                              **network_params).to(device)
        target_net = mcts.DiscreteSupportPVNet_v3(gym_env,
                                                  **network_params).to(device)
    else:
        pv_net = mcts.FixedDynamicsPVNet_v3(gym_env,
                                            **network_params).to(device)
        target_net = mcts.FixedDynamicsPVNet_v3(gym_env,
                                                **network_params).to(device)

    # Share memory of the 'actor' model, i.e. pv_net; it might not even be necessary at this point
    pv_net.share_memory()

    # Init target_net with the same parameters as pv_net
    for trg_params, params in zip(target_net.parameters(),
                                  pv_net.parameters()):
        trg_params.data.copy_(params.data)

    # Training and optimization
    optimizer = torch.optim.Adam(pv_net.parameters(), lr=args.lr)
    # decrease lr by 2 orders of magnitude over the course of training
    gamma = 10**(-2 / (args.n_episodes - 1))
    # decrease temperature by 1 order of magnitude over the course of training
    gamma_T = 10**(-1 / (args.n_episodes - 1))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
    replay_buffer = train.HopPolicyValueReplayBuffer(args.memory_size,
                                                     args.discount)

    # Experiment ID
    if args.ID is None:
        ID = gen_PID()
    else:
        ID = args.ID
    print("Experiment ID: ", ID)

    total_rewards = []
    entropies = []
    losses = []
    policy_losses = []
    value_losses = []

    for i in range(args.n_episodes):
        ### Generate experience ###
        t0 = time.time()
        mode = "predict"
        target_net.eval()  # just to make sure
        pv_net.eval()

        results = train.play_rollout_pv_net_hop_mcts(
            args.episode_length,
            object_ids,
            game_simulator,
            args.ucb_C,
            args.discount,
            args.max_actions,
            pv_net,
            args.num_simulations,
            args.num_trees,
            temperature,
            dirichlet_alpha=args.dirichlet_alpha,
            exploration_fraction=args.exploration_fraction,
            ucb_method=ucb_method)
        total_reward, frame_lst, reward_lst, done_lst, action_lst, probs_lst = results
        replay_buffer.store_episode(frame_lst, reward_lst, done_lst,
                                    action_lst, probs_lst)
        total_rewards.append(total_reward)
        rollout_time = (time.time() - t0) / 60
        if (i + 1) % 10 == 0:
            print("\nEpisode %d - Total reward %d " % (i + 1, total_reward))
            print("Rollout time: %.2f" % (rollout_time))

        if i >= args.batch_size:
            ### Update ###
            target_net.eval()  # just to make sure
            frames, target_values, actions, probs = replay_buffer.get_batch(
                args.batch_size, args.n_steps, target_net, device)
            pv_net.train()
            update_results = train.compute_PV_net_update_v1(
                pv_net, frames, target_values, actions, probs, optimizer,
                args.full_cross_entropy, args.entropy_bonus,
                args.entropy_weight, args.discrete_support_values)
            loss, entropy, policy_loss, value_loss = update_results
            scheduler.step()
            temperature = gamma_T * temperature

            # update target network only every 8 episodes
            if (i + 1) % 8 == 0:
                train.update_target_net(target_net, pv_net, args.tau)

            if (i + 1) % 10 == 0:
                print("Loss: %.4f - Policy loss: %.4f - Value loss: %.4f" %
                      (loss, policy_loss, value_loss))
                print("Entropy: %.4f" % entropy)
            losses.append(loss)
            entropies.append(entropy)
            policy_losses.append(policy_loss)
            value_losses.append(value_loss)

        if (i + 1) % 50 == 0:
            # Print update
            print("\nAverage reward over last 50 rollouts: %.2f\n" %
                  (np.mean(total_rewards[-50:])))

        if (i + 1) % args.checkpoint_period == 0:
            # Plot histograms of value stats and save checkpoint
            target_net.eval()
            pv_net.eval()

            # No plots in the script
            #train.plot_value_stats(value_net, target_net, rb, batch_size, n_steps, discount, device)

            d = dict(
                episodes_played=i,
                training_params=training_params,
                object_ids=object_ids,
                pv_net=pv_net,
                target=target_net,
                losses=losses,
                policy_losses=policy_losses,
                value_losses=value_losses,
                total_rewards=total_rewards,
                entropies=entropies,
                optimizer=optimizer,
            )

            experiment_path = "%s/%s/" % (args.save_dir, ID)
            if not os.path.isdir(experiment_path):
                os.mkdir(experiment_path)
            torch.save(d, experiment_path + 'training_dict_%d' % (i + 1))
            torch.save(replay_buffer, experiment_path + 'replay_buffer')
            torch.save(network_params, experiment_path + 'network_params')
            print("Saved checkpoint.")

    end = time.time()
    elapsed = (end - start) / 60
    print("Run took %.1f min." % elapsed)
Example 4
def main():
    start = time.time()
    # define them by the parser values
    training_params = dict(ucb_C=args.ucb_C,
                           discount=args.discount,
                           episode_length=args.episode_length,
                           max_actions=args.max_actions,
                           num_simulations=args.num_simulations,
                           device=args.device,
                           n_episodes=args.n_episodes,
                           memory_size=args.memory_size,
                           batch_size=args.batch_size,
                           n_steps=args.n_steps,
                           tau=args.tau)

    device = args.device

    # Environment and simulator
    flags = utils.Flags(env="rtfm:groups_simple_stationary-v0")
    gym_env = utils.create_env(flags)
    featurizer = X.Render()
    game_simulator = mcts.FullTrueSimulator(gym_env, featurizer)
    object_ids = utils.get_object_ids_dict(game_simulator)

    # Networks
    value_net = mcts.FixedDynamicsValueNet_v2(gym_env).to(device)
    target_net = mcts.FixedDynamicsValueNet_v2(gym_env).to(device)
    # Init target_net with the same parameters as value_net
    for trg_params, params in zip(target_net.parameters(),
                                  value_net.parameters()):
        trg_params.data.copy_(params.data)

    # Training and optimization
    optimizer = torch.optim.Adam(value_net.parameters(), lr=args.lr)
    # decrease lr by 2 orders of magnitude over the course of training
    gamma = 10**(-2 / (args.n_episodes / args.net_update_period - 1))
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
    loss_fn = F.mse_loss
    rb = train.nStepsReplayBuffer(args.memory_size, args.discount)

    # Experiment ID
    if args.ID is None:
        ID = gen_PID()
    else:
        ID = args.ID
    print("Experiment ID: ", ID)

    total_rewards = []
    losses = []
    for i in range(args.n_episodes):
        ### Generate experience ###
        t0 = time.time()
        value_net.eval()
        total_reward, frame_lst, reward_lst, done_lst = train.play_rollout_value_net(
            value_net,
            game_simulator,
            args.episode_length,
            args.ucb_C,
            args.discount,
            args.max_actions,
            args.num_simulations,
            mode="predict",
            bootstrap="no")
        t1 = time.time()
        total_rewards.append(total_reward)
        print("\nEpisode %d - Total reward %d" % (i + 1, total_reward))
        rollout_time = (t1 - t0) / 60
        print("Rollout time: %.2f" % (rollout_time))
        rb.store_episode(frame_lst, reward_lst, done_lst)

        ### Train value_net ###

        try:
            # update value network every net_update_period episodes
            if (i + 1) % args.net_update_period == 0:
                target_net.eval()
                frames, targets = rb.get_batch(args.batch_size, args.n_steps,
                                               args.discount, target_net,
                                               device)
                value_net.train()
                loss = train.compute_update_v1(value_net, frames, targets,
                                               loss_fn, optimizer)
                scheduler.step()
                print("Loss: %.4f" % loss)
                losses.append(loss)
            # update target network only from time to time
            if (i + 1) % args.target_update_period == 0:
                train.update_target_net(target_net, value_net, args.tau)

        except Exception:
            # skip the update if it fails (e.g. not enough episodes
            # in the replay buffer yet to sample a batch)
            pass

        if (i + 1) % 50 == 0:
            # Print update
            print("\nAverage reward over last 50 rollouts: %.2f\n" %
                  (np.mean(total_rewards[-50:])))

        if (i + 1) % args.checkpoint_period == 0:
            # Plot histograms of value stats and save checkpoint
            target_net.eval()
            value_net.eval()

            # No plots in the script
            #train.plot_value_stats(value_net, target_net, rb, batch_size, n_steps, discount, device)

            d = dict(episodes_played=i,
                     training_params=training_params,
                     object_ids=object_ids,
                     value_net=value_net,
                     target_net=target_net,
                     rb=rb,
                     losses=losses,
                     total_rewards=total_rewards)

            experiment_path = "./%s/%s/" % (args.save_dir, ID)
            if not os.path.isdir(experiment_path):
                os.mkdir(experiment_path)
            torch.save(d, experiment_path + 'training_dict_%d' % (i + 1))
            print("Saved checkpoint.")

    end = time.time()
    elapsed = (end - start) / 60
    print("Run took %.1f min." % elapsed)
Example 5
            print(
                f'\n\33[1mReceived \33[35m{utils.prtcls[iparse.prtcl]}\33[39m packet,'
                ' ignoring...\33[0m (you can use -v to display all IPv4 packets)'
            )
            continue

        # IP payload (includes the entire TCP segment)
        idata = iparse.data
        # Parse TCP packet
        tcparse = parse.tcp(idata)

        # Some pretty prints
        utils.print_pac(iparse, tcparse)

        # Make the flags from the packet into a Flags object and print
        flags = utils.Flags(tcparse.flags)
        print(flags)

        # ------------ Manage TCP connections ---------------
        quad = tcp.Quad(iparse.srcip, tcparse.src_port, iparse.dstip,
                        tcparse.dst_port)

        # Check if connection already exists, if not, create one
        conn_exists = False
        conn = None
        for con in conns:
            if con.quad == quad:  # The packet is for an existing connection
                conn_exists = True
                conn = con
                break
        if not conn_exists:  # Start a new connection
Example 6
    def mkpkt(self, data=b'', flags=0):
        '''Wrapper for utils.mkpkt to automatically use this object's properties'''
        if isinstance(flags, (int, str)):
            flags = utils.Flags(flags)
        return utils.mkpkt(data, self.quad, flags, self.acknm, self.sqnm)
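
The wrapper above normalises whatever the caller passes: a raw int bitmask or a string flag name gets wrapped in utils.Flags, while an already-built Flags object is passed through unchanged. Hypothetical call sites (conn being a connection object of this class):

# ack = conn.mkpkt(flags='ack')               # string name, wrapped in utils.Flags
# synack = conn.mkpkt(flags=0x12)             # raw bitmask, wrapped in utils.Flags
# fin = conn.mkpkt(flags=utils.Flags('fin'))  # pre-built Flags, passed through
# conn.send(ack)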
Example 7
    print("You need root privileges to run this application!")
    exit(-1)

# Clear the screen
os.system("clear")

# Checking debug flag status (-d or --debug)
debug = utils.is_debug()

# This sets the logger threshold level
if debug:
    log_level = "DEBUG"
else:
    log_level = "INFO"

flags = utils.Flags()
dropping_policy = flags.dropping_policy
log_level = flags.log_level
# TODO: replace the debug variable with log_level and add the checks,
# e.g. if log_level == "DEBUG" or log_level == "ALL"?
# Alternatively, add a self.debug variable to utils if debug mode
# should be kept distinct from debug-level logging
debug = flags.debug
print(flags)

# Indispensable objects instantiation
log = logger.Log(log_file, log_level)
shield = analysis.Shield(log, queue_number, services_file)
handling = packet_handling.PacketHandling(log, shield, debug, dropping_policy)

# Optional objects instantiation: comment them to disable