Example #1
def test_garbled_circuit():
    Alice = garbler(PRS_circuit, {
        'igate_A': None,
        'igate_B': None
    })  # Alice's inputs are A,B chosen randomly
    garbled_table = Alice.garble()
    Bob_choice = None
    while Bob_choice not in [
            'PAPER', 'ROCK', 'SCISSORS', 'LOSE', 'P', 'R', 'S', 'L'
    ]:
        Bob_choice = input(
            "Bob's choice is PAPER (P), ROCK (R), SCISSORS (S) or LOSE (L): ")
    C, D = choice_to_bin(Bob_choice)
    OT_receiver = OT.Receiver()
    Bob = evaluator(PRS_circuit, {
        'igate_C': C,
        'igate_D': D
    }, OT_receiver)  # Bob's inputs are C,D chosen by user

    # @student: Why does this not reveal Alice's inputs to Bob?
    Alice_input_keys = Alice.input_keys()
    # @student: Why does this not reveal Bob's inputs to Alice?
    Bob_input_keys = Bob.oblivious_transfer(Alice)
    circuit_outputs = Bob.evaluate(garbled_table, Bob_input_keys,
                                   Alice_input_keys)
    print(result(circuit_outputs['gate_E'], circuit_outputs['gate_F']))
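
For reference, a minimal sketch of what the `choice_to_bin` helper used above could look like; the actual two-bit encoding is fixed by `PRS_circuit`, so the mapping below is an assumption for illustration only.

def choice_to_bin(choice):
    # Hypothetical encoding: the real bit assignment is defined by PRS_circuit.
    mapping = {'P': (0, 0), 'R': (0, 1), 'S': (1, 0), 'L': (1, 1)}
    return mapping[choice.strip().upper()[0]]  # accepts 'PAPER' as well as 'P'
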
Example #2
def oblivious_transfer(self, gate_id, c, pk):
    # Engage OT only on the evaluator's input gates.
    assert gate_id not in self.myinputs
    assert self.circuit.G[gate_id].is_circuit_input
    k0, k1 = self.output_table[gate_id]
    OT_Sender = OT.Sender(k0, k1)
    response = OT_Sender.response(c, pk)
    return response
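
To show how this sender method is consumed, here is a hedged sketch of the matching evaluator-side exchange, mirroring the usage in `evaluate_garbled_circuit` further below; the `OT` module API and the `garbler` object are assumed to be the lab's.

def fetch_input_key(garbler, gate_id, b):
    # Sketch only: evaluator-side 1-out-of-2 OT matching the sender method above.
    receiver = OT.Receiver()
    pk, c = receiver.pk, receiver.challenge(b)              # hide the chosen bit b in c
    c_0, c_1 = garbler.oblivious_transfer(gate_id, c, pk)   # garbler answers without learning b
    return receiver.decrypt_response(c_0, c_1, b)           # recovers exactly k_b, never both keys
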
def build_model(batch, train_data):
    """Assembles the seq2seq model.
    """
    source_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.source_vocab.size, hparams=config_model.embedder)

    encoder = tx.modules.BidirectionalRNNEncoder(hparams=config_model.encoder)

    enc_outputs, _ = encoder(source_embedder(batch['source_text_ids']))

    target_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.target_vocab.size, hparams=config_model.embedder)

    decoder = tx.modules.AttentionRNNDecoder(
        memory=tf.concat(enc_outputs, axis=2),
        memory_sequence_length=batch['source_length'],
        vocab_size=train_data.target_vocab.size,
        hparams=config_model.decoder)

    training_outputs, _, _ = decoder(
        decoding_strategy='train_greedy',
        inputs=target_embedder(batch['target_text_ids'][:, :-1]),
        sequence_length=batch['target_length'] - 1)

    # MLE loss (combined with the OT word-matching loss below)
    MLE_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=batch['target_text_ids'][:, 1:],
        logits=training_outputs.logits,
        sequence_length=batch['target_length'] - 1)

    # Keyword-matching loss via optimal transport (IPOT over cosine costs)
    tgt_logits = training_outputs.logits
    tgt_words = target_embedder(soft_ids=tgt_logits)
    src_words = source_embedder(ids=batch['source_text_ids'])
    src_words = tf.nn.l2_normalize(src_words, 2, epsilon=1e-12)
    tgt_words = tf.nn.l2_normalize(tgt_words, 2, epsilon=1e-12)

    cosine_cost = 1 - tf.einsum('aij,ajk->aik', src_words,
                                tf.transpose(tgt_words, [0, 2, 1]))
    OT_loss = tf.reduce_mean(OT.IPOT_distance2(cosine_cost))

    Total_loss = MLE_loss + 0.1 * OT_loss

    train_op = tx.core.get_train_op(Total_loss, hparams=config_model.opt)


    start_tokens = tf.ones_like(batch['target_length']) *\
                   train_data.target_vocab.bos_token_id
    beam_search_outputs, _, _ = \
        tx.modules.beam_search_decode(
            decoder_or_cell=decoder,
            embedding=target_embedder,
            start_tokens=start_tokens,
            end_token=train_data.target_vocab.eos_token_id,
            beam_width=config_model.beam_width,
            max_decoding_length=60)

    return train_op, beam_search_outputs
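
`OT.IPOT_distance2` above refers to an Inexact Proximal point method for Optimal Transport. As a rough reference, a minimal (non-batched) NumPy sketch of the IPOT iteration is given below; it is not the library implementation, and the `beta` value and iteration count are assumptions.

import numpy as np

def ipot_distance(C, beta=0.5, n_iter=50):
    """Approximate transport cost <T, C> for a single cost matrix C (n x m)."""
    n, m = C.shape
    sigma = np.ones(m) / m
    T = np.ones((n, m))
    A = np.exp(-C / beta)                      # proximal kernel
    for _ in range(n_iter):
        Q = A * T
        delta = 1.0 / (n * Q.dot(sigma))       # row scaling
        sigma = 1.0 / (m * Q.T.dot(delta))     # column scaling
        T = delta[:, None] * Q * sigma[None, :]
    return float(np.sum(T * C))
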
Example #4
def evaluate_garbled_circuit(circuit, myinputs, garbled_table, input_keys,
                             ot_senders):
    """Evaluate a garbled circuit

    :param circuit: circuit to evaluate
    :type circuit: logic_circuit.Circuit
    :param myinputs: known inputs, to be kept hidden
    :type myinputs: dictionary {gate_id: 0/1}
    :param garbled_table: Table of garbled logic gates
    :type garbled_table: dictionary {gate_id: 4*[AES_key]}
    :param input_keys: ungarbling keys already known
    :type input_keys: dictionary {input_gate_id: AES_key}
    :param ot_senders: OT senders to recover missing input keys using myinputs
        values
    :return: State of the evaluated circuit
    :rtype: dictionary {gate_id: gate_output_value}
    """
    state = input_keys.copy()
    # ---- Input validation ----
    for g_id, g_value in six.iteritems(myinputs):
        assert circuit.g[g_id].kind == "INPUT"
        assert g_value in (0, 1)
    assert set(ot_senders) == set(myinputs)

    for i, b in myinputs.items():
        Bob = OT.Receiver()
        pk, c = Bob.pk, Bob.challenge(b)
        c_0, c_1 = ot_senders[i].response(c, pk)
        state[i] = Bob.decrypt_response(c_0, c_1, b)

    # ---- Recursive ungarbling ----
    def _evaluate_garbled_gate_rec(g_id):
        gate = circuit.g[g_id]
        if gate.in0_id not in state:
            _evaluate_garbled_gate_rec(gate.in0_id)
        if gate.in1_id not in state:
            _evaluate_garbled_gate_rec(gate.in1_id)
        key0 = state[gate.in0_id]
        key1 = state[gate.in1_id]
        # Free-XOR optimization: no decryption needed for inner XOR gates
        if gate.kind == "XOR" and g_id not in circuit.output_gates:
            state[g_id] = AES_key.from_int(key0.as_int() ^ key1.as_int())
        else:
            for line in garbled_table[g_id]:
                decoded_line = _decode_decryption(
                    key1.decrypt(key0.decrypt(line)))
                if decoded_line is not None:
                    state[g_id] = decoded_line

    for g_id in circuit.output_gates:
        _evaluate_garbled_gate_rec(g_id)

    return state
Example #5
def evaluate_garbled_circuit(circuit, myinputs, garbled_table, input_keys,
                             ot_senders):
    """Evaluate a garbled circuit

    :param circuit: circuit to evaluate
    :type circuit: lib.logic_circuit.Circuit
    :param myinputs: known inputs, to be kept hidden
    :type myinputs: dictionary {gate_id: 0/1}
    :param garbled_table: Table of garbled logic gates
    :type garbled_table: dictionary {gate_id: 4*[AES_key]}
    :param input_keys: ungarbling keys already known
    :type input_keys: dictionary {input_gate_id: AES_key}
    :param ot_senders: OT senders to recover missing input keys using myinputs
                        values
    :return: State of the evaluated circuit
    :rtype: dictionary {gate_id: gate_output_value}
    """
    # @students: Which key steps in this function ensure that Bob's inputs
    #            are not revealed to Alice?
    state = input_keys.copy()
    # ---- Input validation ----
    for g_id, g_value in six.iteritems(myinputs):
        assert circuit.g[g_id].kind == "INPUT"
        assert g_value in (0, 1)
    assert set(ot_senders) == set(myinputs)

    # **************************************************************************
    # ---- make OTs, store resulting keys in state ----
    # <to be completed by students>
    for i, b in myinputs.items():
        Bob = OT.Receiver()
        pk, c = Bob.pk, Bob.challenge(b)
        c_0, c_1 = ot_senders[i].response(c, pk)
        state[i] = Bob.decrypt_response(c_0, c_1, b)
    # </to be completed by students>
    # **************************************************************************

    # ---- Recursive ungarbling ----
    def _evaluate_garbled_gate_rec(g_id):
        # **********************************************************************
        # Exercise 2
        # ==========
        # (b) Complete this part of the function.
        # <to be completed by students>
        #  fetch the gate object for this id first
        gate = circuit.g[g_id]
        #  if gate's inputs are not already evaluated, recursively do so
        if gate.in0_id not in state:
            _evaluate_garbled_gate_rec(gate.in0_id)
        if gate.in1_id not in state:
            _evaluate_garbled_gate_rec(gate.in1_id)
        # now that inputs are evaluated, get their related keys from the state
        key0 = state[gate.in0_id]
        key1 = state[gate.in1_id]
        # decrypt each line of the received garbled circuit table (GCT) until
        #  one decodes correctly (the others cannot, since only one pair of
        #  input keys yields a valid decryption)
        for line in garbled_table[g_id]:
            decoded_line = _decode_decryption(key1.decrypt(key0.decrypt(line)))
            if decoded_line is not None:
                # if decoded_line is an AES key, downstream gates remain to be
                #  evaluated; if it is an integer, we have reached an output
                #  gate of the circuit
                state[g_id] = decoded_line
        # </to be completed by students>
        # **********************************************************************

    for g_id in circuit.output_gates:
        _evaluate_garbled_gate_rec(g_id)

    return state
Example #6
def garble_circuit(circuit, myinputs):
    """Garble a circuit

    :param circuit: circuit to garble
    :type circuit: logic_circuit.Circuit
    :param myinputs: already known inputs, to be hidden
    :type myinputs: dictionary {gate_id: 0/1}
    :return: Garbled circuit, ungarbling keys associated to myinputs and OT
        senders for other inputs.
    :rtype: (garbled_table, input_keys, ot_senders)

    - garbled_table: dictionary {gate_id: 4*[AES_key]}
    - input_keys: dictionary {input_gate_id: AES_key}
    - ot_senders: dictionary {input_gate_id: OT.Sender}
    """
    # @students: Which key steps in this function ensure that Alice's inputs
    #            are not revealed to Bob?

    # Garbling keys (k_0, k_1) for each gate => secret
    output_table = {}
    # Garbled table for each gate => public
    garbled_table = {}
    # Ungarbling keys associated with my inputs => public
    input_keys = {}
    # OT senders for the other party's inputs => public
    ot_senders = {}

    # ---- Input validation ----
    for g_id, g_value in six.iteritems(myinputs):
        assert circuit.g[g_id].kind == "INPUT"
        assert g_value in (0, 1)

    # ---- Garbling keys generation ----
    for g_id in circuit.g:
        # For output gates, we encrypt the binary output instead of an AES key.
        if g_id not in circuit.output_gates:
            k_0 = AES_key.gen_random()
            k_1 = AES_key.gen_random()
            output_table[g_id] = (k_0, k_1)

    # ---- Garbled tables generation ----
    for g_id, gate in six.iteritems(circuit.g):
        # We already retrieved the values for all the input gates.
        if gate.kind != "INPUT":
            K_0 = output_table[gate.in0_id]  # K_0 = k_00, k_01
            K_1 = output_table[gate.in1_id]  # K_1 = k_10, k_11
            c_list = []
            for i in range(2):
                for j in range(2):
                    # 'real' evaluation of the gate on i,j
                    alpha = Gate.compute_gate(gate.kind, i, j)
                    if g_id in circuit.output_gates:
                        m = _encode_int(alpha)  # 0 or 1
                    else:
                        K = output_table[g_id]
                        m = _encode_key(K[alpha])  # k_0 or k_1 (see above)
                    c = K_1[j].encrypt(m)
                    c_ij = K_0[i].encrypt(c)
                    c_list.append(c_ij)
            # @students: Why is it important to shuffle the list?
            # ANSWER: with a fixed (i, j) ordering of the c_ij's (the Garbled
            #          Circuit Table rows), the row that decrypts would reveal
            #          which input bits were used; shuffling hides this.
            random.shuffle(c_list)
            garbled_table[g_id] = c_list

    # ---- Ungarbling keys generation for my inputs ----
    for g_id, input_val in six.iteritems(myinputs):
        K = output_table[g_id]
        key = K[input_val]  # key = K[i] where i in [0,1] is my input
        input_keys[g_id] = key

    # ---- Oblivious transfer senders ----
    for g_id, gate in six.iteritems(circuit.g):
        if gate.kind == "INPUT" and g_id not in myinputs:
            k0, k1 = output_table[g_id]
            ot_senders[g_id] = OT.Sender(k0, k1)
    return (garbled_table, input_keys, ot_senders)
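
For orientation, a hedged end-to-end sketch tying this garbler to the `evaluate_garbled_circuit` function from the earlier examples; the single-AND-gate circuit and its gate ids (`in_A`, `in_B`, `gate_out`) are assumptions for illustration.

# Alice garbles with her secret input bit; 'in_B' belongs to Bob.
garbled_table, alice_keys, ot_senders = garble_circuit(and_circuit, {'in_A': 1})
# Bob evaluates with his own secret bit, fetching his input key via OT.
state = evaluate_garbled_circuit(and_circuit, {'in_B': 0}, garbled_table,
                                 alice_keys, ot_senders)
print(state['gate_out'])  # 1 AND 0 -> 0
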
def main():
    """Entrypoint.
    """
    # Load data
    train_data, dev_data, test_data = data_utils.load_data_numpy(
        config_data.input_dir, config_data.filename_prefix)
    with open(config_data.vocab_file, 'rb') as f:
        id2w = pickle.load(f)
    vocab_size = len(id2w)

    beam_width = config_model.beam_width

    # Create logging
    tx.utils.maybe_create_dir(FLAGS.model_dir)
    logging_file = os.path.join(FLAGS.model_dir, 'logging.txt')
    logger = utils.get_logger(logging_file)
    print('logging file is saved in: %s' % logging_file)

    # Build model graph
    encoder_input = tf.placeholder(tf.int64, shape=(None, None))
    decoder_input = tf.placeholder(tf.int64, shape=(None, None))
    batch_size = tf.shape(encoder_input)[0]
    # (text sequence length excluding padding)
    encoder_input_length = tf.reduce_sum(
        1 - tf.cast(tf.equal(encoder_input, 0), tf.int32), axis=1)

    labels = tf.placeholder(tf.int64, shape=(None, None))
    is_target = tf.cast(tf.not_equal(labels, 0), tf.float32)

    global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
    learning_rate = tf.placeholder(tf.float64, shape=(), name='lr')

    # Source word embedding
    src_word_embedder = tx.modules.WordEmbedder(vocab_size=vocab_size,
                                                hparams=config_model.emb)
    src_word_embeds = src_word_embedder(encoder_input)
    src_word_embeds = src_word_embeds * config_model.hidden_dim**0.5

    # Position embedding (shared b/w source and target)
    pos_embedder = tx.modules.SinusoidsPositionEmbedder(
        position_size=config_data.max_decoding_length,
        hparams=config_model.position_embedder_hparams)
    src_seq_len = tf.ones([batch_size], tf.int32) * tf.shape(encoder_input)[1]
    src_pos_embeds = pos_embedder(sequence_length=src_seq_len)

    src_input_embedding = src_word_embeds + src_pos_embeds

    encoder = TransformerEncoder(hparams=config_model.encoder)
    encoder_output = encoder(inputs=src_input_embedding,
                             sequence_length=encoder_input_length)

    # The decoder ties the input word embedding with the output logit layer.
    # Since the decoder masks out <PAD>'s embedding (in effect giving <PAD> an
    # all-zero embedding), we explicitly set <PAD>'s embedding to all-zero here.
    tgt_embedding = tf.concat([
        tf.zeros(shape=[1, src_word_embedder.dim]),
        src_word_embedder.embedding[1:, :]
    ],
                              axis=0)
    tgt_embedder = tx.modules.WordEmbedder(tgt_embedding)
    tgt_word_embeds = tgt_embedder(decoder_input)
    tgt_word_embeds = tgt_word_embeds * config_model.hidden_dim**0.5

    tgt_seq_len = tf.ones([batch_size], tf.int32) * tf.shape(decoder_input)[1]
    tgt_pos_embeds = pos_embedder(sequence_length=tgt_seq_len)

    tgt_input_embedding = tgt_word_embeds + tgt_pos_embeds

    _output_w = tf.transpose(tgt_embedder.embedding, (1, 0))

    decoder = TransformerDecoder(vocab_size=vocab_size,
                                 output_layer=_output_w,
                                 hparams=config_model.decoder)
    # For training
    outputs = decoder(memory=encoder_output,
                      memory_sequence_length=encoder_input_length,
                      inputs=tgt_input_embedding,
                      decoding_strategy='train_greedy',
                      mode=tf.estimator.ModeKeys.TRAIN)
    # Graph matching in Transformer
    _tgt_embedding = tgt_embedder(soft_ids=outputs.logits)

    src_words = tf.nn.l2_normalize(src_word_embeds, 2, epsilon=1e-12)
    tgt_words = tf.nn.l2_normalize(_tgt_embedding, 2, epsilon=1e-12)

    cosine_cost = 1 - tf.einsum('aij,ajk->aik', src_words,
                                tf.transpose(tgt_words, [0, 2, 1]))
    # NOTE: prune cosine costs below a relative threshold
    _beta = 0.2
    minval = tf.reduce_min(cosine_cost)
    maxval = tf.reduce_max(cosine_cost)
    threshold = minval + _beta * (maxval - minval)
    cosine_cost = tf.nn.relu(cosine_cost - threshold)

    # Gromov-Wasserstein distance (intra-source vs. intra-target structure)
    Cs = 1 - tf.einsum('aij,ajk->aik', src_words,
                       tf.transpose(src_words, [0, 2, 1]))
    Ct = 1 - tf.einsum('aij,ajk->aik', tgt_words,
                       tf.transpose(tgt_words, [0, 2, 1]))
    Css = OT.prune(Cs)
    Ctt = OT.prune(Ct)

    # OT_loss = tf.reduce_mean(OT.IPOT_distance2(cosine_cost))
    # GW_loss = tf.reduce_mean(OT.GW_distance(Css, Ctt))
    GW_loss, W_loss = OT.FGW_distance(Css, Ctt, cosine_cost)
    FGW_loss = tf.reduce_mean(0.1 * GW_loss + 1 * W_loss)

    mle_loss = transformer_utils.smoothing_cross_entropy(
        outputs.logits, labels, vocab_size, config_model.loss_label_confidence)
    mle_loss = tf.reduce_sum(mle_loss * is_target) / tf.reduce_sum(is_target)

    total_loss = mle_loss + FGW_loss * 0.1

    train_op = tx.core.get_train_op(total_loss,
                                    learning_rate=learning_rate,
                                    global_step=global_step,
                                    hparams=config_model.opt)

    tf.summary.scalar('lr', learning_rate)
    tf.summary.scalar('mle_loss', mle_loss)
    summary_merged = tf.summary.merge_all()

    # For inference (beam-search)
    start_tokens = tf.fill([batch_size], bos_token_id)

    def _embedding_fn(x, y):
        x_w_embed = tgt_embedder(x)
        y_p_embed = pos_embedder(y)
        return x_w_embed * config_model.hidden_dim**0.5 + y_p_embed

    predictions = decoder(memory=encoder_output,
                          memory_sequence_length=encoder_input_length,
                          beam_width=beam_width,
                          length_penalty=config_model.length_penalty,
                          start_tokens=start_tokens,
                          end_token=eos_token_id,
                          embedding=_embedding_fn,
                          max_decoding_length=config_data.max_decoding_length,
                          mode=tf.estimator.ModeKeys.PREDICT)
    # Use the best beam (index 0) returned by beam search
    beam_search_ids = predictions['sample_id'][:, :, 0]

    saver = tf.train.Saver(max_to_keep=5)
    best_results = {'score': 0, 'epoch': -1}

    def _eval_epoch(sess, epoch, mode):
        if mode == 'eval':
            eval_data = dev_data
        elif mode == 'test':
            eval_data = test_data
        else:
            raise ValueError('`mode` should be either "eval" or "test".')

        references, hypotheses = [], []
        bsize = config_data.test_batch_size
        for i in range(0, len(eval_data), bsize):
            sources, targets = zip(*eval_data[i:i + bsize])
            x_block = data_utils.source_pad_concat_convert(sources)
            feed_dict = {
                encoder_input: x_block,
                tx.global_mode(): tf.estimator.ModeKeys.EVAL,
            }
            fetches = {
                'beam_search_ids': beam_search_ids,
            }
            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            hypotheses.extend(h.tolist() for h in fetches_['beam_search_ids'])
            references.extend(r.tolist() for r in targets)
            hypotheses = utils.list_strip_eos(hypotheses, eos_token_id)
            references = utils.list_strip_eos(references, eos_token_id)

        if mode == 'eval':
            # Writes results to files to evaluate BLEU
            # For 'eval' mode, the BLEU is based on token ids (rather than
            # text tokens) and serves only as a surrogate metric to monitor
            # the training process
            fname = os.path.join(FLAGS.model_dir, 'tmp.eval')
            hypotheses = tx.utils.str_join(hypotheses)
            references = tx.utils.str_join(references)
            hyp_fn, ref_fn = tx.utils.write_paired_text(hypotheses,
                                                        references,
                                                        fname,
                                                        mode='s')
            eval_bleu = bleu_wrapper(ref_fn, hyp_fn, case_sensitive=True)
            eval_bleu = 100. * eval_bleu
            logger.info('epoch: %d, eval_bleu %.4f', epoch, eval_bleu)
            print('epoch: %d, eval_bleu %.4f' % (epoch, eval_bleu))

            if eval_bleu > best_results['score']:
                logger.info('epoch: %d, best bleu: %.4f', epoch, eval_bleu)
                best_results['score'] = eval_bleu
                best_results['epoch'] = epoch
                model_path = os.path.join(FLAGS.model_dir, 'best-model.ckpt')
                logger.info('saving model to %s', model_path)
                print('saving model to %s' % model_path)
                saver.save(sess, model_path)

        elif mode == 'test':
            # For 'test' mode, together with the cmds in README.md, BLEU
            # is evaluated based on text tokens, which is the standard metric.
            fname = os.path.join(FLAGS.model_dir, 'test.output')
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append([id2w[y] for y in hyp])
                rwords.append([id2w[y] for y in ref])
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_fn, ref_fn = tx.utils.write_paired_text(hwords,
                                                        rwords,
                                                        fname,
                                                        mode='s',
                                                        src_fname_suffix='hyp',
                                                        tgt_fname_suffix='ref')
            logger.info('Test output written to file: %s', hyp_fn)
            print('Test output written to file: %s' % hyp_fn)

    def _train_epoch(sess, epoch, step, smry_writer):
        random.shuffle(train_data)
        train_iter = data.iterator.pool(
            train_data,
            config_data.batch_size,
            key=lambda x: (len(x[0]), len(x[1])),
            batch_size_fn=utils.batch_size_fn,
            random_shuffler=data.iterator.RandomShuffler())

        for _, train_batch in enumerate(train_iter):
            in_arrays = data_utils.seq2seq_pad_concat_convert(train_batch)
            feed_dict = {
                encoder_input: in_arrays[0],
                decoder_input: in_arrays[1],
                labels: in_arrays[2],
                learning_rate: utils.get_lr(step, config_model.lr)
            }
            fetches = {
                'step': global_step,
                'train_op': train_op,
                'smry': summary_merged,
                'loss': mle_loss,
            }

            fetches_ = sess.run(fetches, feed_dict=feed_dict)

            step, loss = fetches_['step'], fetches_['loss']
            if step and step % config_data.display_steps == 0:
                logger.info('step: %d, loss: %.4f', step, loss)
                print('step: %d, loss: %.4f' % (step, loss))
                smry_writer.add_summary(fetches_['smry'], global_step=step)

            if step and step % config_data.eval_steps == 0:
                _eval_epoch(sess, epoch, mode='eval')
        return step

    # Run the graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        smry_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        if FLAGS.run_mode == 'train_and_evaluate':
            logger.info('Begin running with train_and_evaluate mode')

            if tf.train.latest_checkpoint(FLAGS.model_dir) is not None:
                logger.info('Restore latest checkpoint in %s' %
                            FLAGS.model_dir)
                saver.restore(sess,
                              tf.train.latest_checkpoint(FLAGS.model_dir))

            step = 0
            for epoch in range(config_data.max_train_epoch):
                step = _train_epoch(sess, epoch, step, smry_writer)

        elif FLAGS.run_mode == 'test':
            logger.info('Begin running with test mode')

            logger.info('Restore latest checkpoint in %s' % FLAGS.model_dir)
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.model_dir))

            _eval_epoch(sess, 0, mode='test')

        else:
            raise ValueError('Unknown mode: {}'.format(FLAGS.run_mode))
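
The training loop above feeds `utils.get_lr(step, config_model.lr)` into the `learning_rate` placeholder. Assuming this follows the usual Transformer warmup/decay ("Noam") schedule, a minimal sketch is given below; the `d_model` and `warmup_steps` values are assumptions.

def noam_lr(step, d_model=512, warmup_steps=4000):
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
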
Example #8
def garble_circuit(circuit, myinputs):
    """Garble a circuit

    :param circuit: circuit to garble
    :type circuit: logic_circuit.Circuit
    :param myinputs: already known inputs, to be hidden
    :type myinputs: dictionary {gate_id: 0/1}
    :return: Garbled circuit, ungarbling keys associated to myinputs and OT
        senders for other inputs.
    :rtype: (garbled_table, input_keys, ot_senders)

    - garbled_table: dictionary {gate_id: 4*[AES_key]}
    - input_keys: dictionary {input_gate_id: AES_key}
    - ot_senders: dictionary {input_gate_id: OT.Sender}
    """
    # Garbling keys (k_0, k_1) for each gate => secret
    output_table = {}
    # Garbled table for each gate => public
    garbled_table = {}
    # Ungarbling keys associated with my inputs => public
    input_keys = {}
    # OT senders for the other party's inputs => public
    ot_senders = {}

    # **************************************************************************
    # Exercise 4
    # ==========
    # We start the Free-XOR optimization (slide 38) here by defining the
    #  global random offset R: each wire's two keys satisfy k_1 = k_0 XOR R.
    R = AES_key.gen_random().as_int()

    # ---- Input validation ----
    for g_id, g_value in six.iteritems(myinputs):
        assert circuit.g[g_id].kind == "INPUT"
        assert g_value in (0, 1)

    # ---- Garbling keys generation ----
    # We modify the original implementation: a gate's keys now depend on the
    #  keys of its input gates, so the gates must be processed in topological
    #  order.
    for g_id in circuit.ordered_gates():
        # For output gates, we encrypt the binary output instead of an AES key.
        if g_id not in circuit.output_gates:
            g = circuit.g[g_id]
            # key generation ; see lecture "Secure Computation", slide 38
            k_0 = AES_key.from_int(output_table[g.in0_id][0].as_int() ^ \
                                   output_table[g.in1_id][0].as_int()) \
                  if g.kind == "XOR" else AES_key.gen_random()
            k_1 = AES_key.from_int(k_0.as_int() ^ R)
            output_table[g_id] = (k_0, k_1)

    # ---- Garbled tables generation ----
    for g_id, gate in six.iteritems(circuit.g):
        # We already retrieved the values for all the input gates.
        # Free-XOR optimization: no garbled table is needed for inner XOR
        #  gates, but output XOR gates still need one (otherwise their result
        #  would be an AES key instead of 0 or 1).
        if gate.kind != "INPUT" and \
                (gate.kind != "XOR" or g_id in circuit.output_gates):
            K_0 = output_table[gate.in0_id]  # K_0 = k_00, k_01
            K_1 = output_table[gate.in1_id]  # K_1 = k_10, k_11
            c_list = []
            for i in range(2):
                for j in range(2):
                    # 'real' evaluation of the gate on i,j
                    alpha = Gate.compute_gate(gate.kind, i, j)
                    if g_id in circuit.output_gates:
                        m = _encode_int(alpha)  # 0 or 1
                    else:
                        K = output_table[g_id]
                        m = _encode_key(K[alpha])  # k_0 or k_1 (see above)
                    c = K_1[j].encrypt(m)
                    c_ij = K_0[i].encrypt(c)
                    c_list.append(c_ij)
            random.shuffle(c_list)
            garbled_table[g_id] = c_list

    # ---- Ungarbling keys generation for my inputs ----
    for g_id, input_val in six.iteritems(myinputs):
        K = output_table[g_id]
        key = K[input_val]  # key = K[i] where i in [0,1] is my input
        input_keys[g_id] = key

    # ---- Oblivious transfer senders ----
    for g_id, gate in six.iteritems(circuit.g):
        if gate.kind == "INPUT" and g_id not in myinputs:
            k0, k1 = output_table[g_id]
            ot_senders[g_id] = OT.Sender(k0, k1)
    return (garbled_table, input_keys, ot_senders)
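
A small self-contained check of the Free-XOR invariant exploited above: with k_1 = k_0 XOR R on every wire, XOR-ing any pair of input keys yields the correct output key, so no garbled table is needed for inner XOR gates. This is a pure-Python illustration, independent of the AES_key class.

import secrets

R = secrets.randbits(128)
k_a0, k_b0 = secrets.randbits(128), secrets.randbits(128)
k_out0 = k_a0 ^ k_b0                         # zero-key of the XOR gate, as generated above
for a in (0, 1):
    for b in (0, 1):
        k_a = k_a0 ^ (a * R)                 # key encoding bit a on the first input wire
        k_b = k_b0 ^ (b * R)                 # key encoding bit b on the second input wire
        assert k_a ^ k_b == k_out0 ^ ((a ^ b) * R)
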