コード例 #1
0
 def to(self, device: torch.device):
     """Move the tensors held by this batch to ``device`` via ``move_to``.

     Each of ``sender_input``, ``labels``, ``receiver_input`` and ``aux_input``
     is passed through ``move_to`` and the result is assigned back onto this
     same object, so the batch is updated in place and no new instance is
     created.

     :param device: target device forwarded to ``move_to``.
     :return: this same instance (``self``), which allows call chaining.
     """
     for field_name in ("sender_input", "labels", "receiver_input", "aux_input"):
         setattr(self, field_name, move_to(getattr(self, field_name), device))
     return self
コード例 #2
0
def evaluate(
    game: nn.Module,
    data: torch.utils.data.DataLoader,
    max_batches: int = 800,
):
    """Run ``game`` in eval mode over ``data`` and average its metrics.

    Every batch is moved to the GPU when one is available (CPU otherwise),
    unpacked into ``game(*batch)``, and the returned interaction is moved to
    the CPU and collected. Progress is printed every 10 batches.

    :param game: module called as ``game(*batch)``; it must return a
        ``(loss, interaction)`` pair whose ``interaction.aux`` holds the
        ``"acc"`` and ``"game_acc"`` entries.
    :param data: loader yielding the batches to evaluate.
    :param max_batches: stop after this many batches (must be positive).
        The default of 800 comes from the k-means use case: only a subset of
        100_000 training elements is clustered, so 128 (bsz) X 800 (batches)
        = 102_400 samples is enough.
    :return: ``(mean_loss, soft_accuracy, game_accuracy, full_interaction)``
        where the first three are per-batch averages and the last one is the
        result of ``Interaction.from_iterable`` over all collected batches.
    :raises ValueError: if ``data`` yields no batch at all.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        game.cuda()
    game.eval()
    device = torch.device("cuda" if use_cuda else "cpu")

    mean_loss = 0.0
    interactions = []
    n_batches = 0
    soft_accuracy, game_accuracy = 0.0, 0.0
    with torch.no_grad():
        for batch in data:
            batch = move_to(batch, device)
            optimized_loss, interaction = game(*batch)

            # Keep the collected interactions on the CPU so that GPU memory
            # does not grow with the number of evaluated batches.
            interaction = interaction.to("cpu")
            interactions.append(interaction)

            mean_loss += optimized_loss
            soft_accuracy += interaction.aux["acc"].mean().item()
            game_accuracy += interaction.aux["game_acc"].mean().item()
            n_batches += 1
            if n_batches % 10 == 0:
                print(f"finished batch {n_batches}")
            if n_batches >= max_batches:
                break

    print(f"processed {n_batches} batches in total")
    if n_batches == 0:
        # Without this guard the averages below fail with an opaque
        # ZeroDivisionError.
        raise ValueError("evaluate() received no batches: `data` is empty")

    mean_loss /= n_batches
    soft_accuracy /= n_batches
    game_accuracy /= n_batches
    full_interaction = Interaction.from_iterable(interactions)

    return mean_loss, soft_accuracy, game_accuracy, full_interaction
コード例 #3
0
def evaluate_test_set(
    game: nn.Module,
    data: torch.utils.data.DataLoader,
    k_means_clusters: KMeans,
    num_clusters: int
):
    """Evaluate ``game`` on ``data`` and attach k-means cluster labels to each batch.

    For every batch the vision module, sender, receiver and loss of ``game``
    are invoked by hand (instead of calling ``game`` itself) so that the
    sender's extra outputs can be logged. The sender's third output is fed to
    ``k_means_clusters.predict`` and the predicted ids are stored as a one-hot
    matrix with ``num_clusters`` columns under ``aux["kmeans"]``.

    :param game: module exposing ``vision_module`` and an inner ``game`` with
        ``sender``, ``receiver`` and ``loss`` attributes.
    :param data: loader yielding ``((x_i, x_j), labels)`` batches.
    :param k_means_clusters: fitted clustering model used through ``predict``.
    :param num_clusters: width of the one-hot cluster matrix.
    :return: ``(mean_loss, soft_accuracy, game_accuracy, full_interaction)``;
        the first three are averaged over the number of batches.
    """
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        game.cuda()
    game.eval()
    device = torch.device("cuda" if has_cuda else "cpu")

    # NOTE(review): positional flags; their meaning is defined by
    # LoggingStrategy, which is not visible here.
    strategy = LoggingStrategy(False, False, True, True, True, False)

    collected = []
    batch_count = 0
    loss_total = 0.0
    soft_acc_total, game_acc_total = 0.0, 0.0

    for batch in data:
        (x_i, x_j), labels = move_to(batch, device)
        if has_cuda:
            x_i, x_j = x_i.cuda(), x_j.cuda()

        with torch.no_grad():
            agents = game.game
            sender = agents.sender

            sender_encoded_input, receiver_encoded_input = game.vision_module(x_i, x_j)
            message, message_like, resnet_output_sender = sender(sender_encoded_input, sender=True)

            # The clustering model works on numpy arrays, so take a CPU round
            # trip and bring the predicted ids back next to `message_like`.
            cluster_ids = torch.from_numpy(
                k_means_clusters.predict(resnet_output_sender.cpu().numpy())
            ).to(device=message_like.device, dtype=torch.int64)

            batch_size = message_like.size()[0]
            cluster_one_hot = torch.zeros((batch_size, num_clusters), device=message_like.device)
            cluster_one_hot.scatter_(1, cluster_ids.view(-1, 1), 1)

            receiver_output, resnet_output_recv = agents.receiver(message, receiver_encoded_input)

            loss, aux_info = agents.loss(
                sender_encoded_input, message, receiver_encoded_input, receiver_output, labels
            )

            if hasattr(sender, "temperature"):
                raw_temperature = sender.temperature
                if isinstance(raw_temperature, torch.nn.Parameter):
                    aux_info["temperature"] = raw_temperature.detach()
                else:
                    aux_info["temperature"] = torch.Tensor([raw_temperature])

            aux_info["message_like"] = message_like
            aux_info["kmeans"] = cluster_one_hot
            aux_info["resnet_output_sender"] = resnet_output_sender
            aux_info["resnet_output_recv"] = resnet_output_recv

            batch_interaction = strategy.filtered_interaction(
                sender_input=sender_encoded_input,
                receiver_input=receiver_encoded_input,
                labels=labels,
                receiver_output=receiver_output.detach(),
                message=message,
                message_length=torch.ones(batch_size),
                aux=aux_info,
            ).to("cpu")
            collected.append(batch_interaction)

            loss_total += loss.mean()
            soft_acc_total += batch_interaction.aux['acc'].mean().item()
            game_acc_total += batch_interaction.aux['game_acc'].mean().item()
            batch_count += 1
            if batch_count % 10 == 0:
                print(f"finished batch {batch_count}")

    print(f"processed {batch_count} in total")
    mean_loss = loss_total / batch_count
    soft_accuracy = soft_acc_total / batch_count
    game_accuracy = game_acc_total / batch_count
    full_interaction = Interaction.from_iterable(collected)

    return mean_loss, soft_accuracy, game_accuracy, full_interaction
コード例 #4
0
ファイル: train.py プロジェクト: cjlovering/EGG
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )  # validation_data=validation_data,
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        is_gs = "gs" in opts.mode
        sender_inputs, messages, receiver_inputs, receiver_outputs, labels = core.dump_sender_receiver(
            game, test_data, is_gs, variable_length=True, device=device)

        _, _, _, train_receiver_outputs, train_labels = core.dump_sender_receiver(
            game, train_data, is_gs, variable_length=True, device=device)

        # Test
        receiver_outputs = move_to(receiver_outputs, device)
        labels = move_to(labels, device)
        receiver_outputs = torch.stack(receiver_outputs)
        labels = torch.stack(labels)

        output_is_vector = opts.mode.lower() in set(
            ["gs-hard", "gs", "rf-deterministic"])

        if output_is_vector:
            tensor_accuracy = receiver_outputs.squeeze().argmax(
                dim=1) == labels
        else:
            tensor_accuracy = receiver_outputs == labels
        accuracy = torch.mean(tensor_accuracy.float()).item()

        # Train
コード例 #5
0
def main(params):
    """Train a Gumbel-Softmax sender/receiver game and optionally evaluate it.

    The steps are: parse the options, build the data iterators, print the
    baseline accuracies, build the agents and the game, train, and, when
    ``opts.evaluate`` is set, dump the test-set interactions, print the
    accuracy and information measures and (if ``opts.dump_msg_folder`` is
    given) write one line per test sample plus summary statistics to a file.

    :param params: raw command-line parameters, forwarded to ``get_params``.
    :raises NotImplementedError: if ``opts.mode`` is anything other than
        ``"gs"`` (compared case-insensitively).
    """
    opts = get_params(params)

    device = torch.device("cuda" if opts.cuda else "cpu")

    # Normalise the mode once so that training and evaluation agree on it.
    # Evaluation used to compare the raw string and therefore took the non-GS
    # path for spellings such as "GS", although a GS game had been trained.
    is_gs = opts.mode.lower() == "gs"

    data_loader = VectorsLoader(
        perceptual_dimensions=opts.perceptual_dimensions,
        n_distractors=opts.n_distractors,
        batch_size=opts.batch_size,
        train_samples=opts.train_samples,
        validation_samples=opts.validation_samples,
        test_samples=opts.test_samples,
        shuffle_train_data=opts.shuffle_train_data,
        dump_data_folder=opts.dump_data_folder,
        load_data_path=opts.load_data_path,
        seed=opts.data_seed,
    )

    train_data, validation_data, test_data = data_loader.get_iterators()

    data_loader.upd_cl_options(opts)

    if opts.max_len > 1:
        baseline_msg = 'Cannot yet compute "smart" baseline value for messages of length greater than 1'
    else:
        baseline_msg = (
            f"\n| Baselines measures with {opts.n_distractors} distractors and messages of max_len = {opts.max_len}:\n"
            f"| Dummy random baseline: accuracy = {1 / (opts.n_distractors + 1)}\n"
        )
        if -1 not in opts.perceptual_dimensions:
            baseline_msg += f'| "Smart" baseline with perceptual_dimensions {opts.perceptual_dimensions} = {compute_baseline_accuracy(opts.n_distractors, opts.max_len, *opts.perceptual_dimensions)}\n'
        else:
            baseline_msg += '| Data was loaded froman external file, thus no perceptual_dimension vector was provided, "smart baseline" cannot be computed\n'

    print(baseline_msg)

    sender = Sender(n_features=data_loader.n_features,
                    n_hidden=opts.sender_hidden)

    receiver = Receiver(n_features=data_loader.n_features,
                        linear_units=opts.receiver_hidden)

    if not is_gs:
        raise NotImplementedError(f"Unknown training mode, {opts.mode}")

    sender = core.RnnSenderGS(
        sender,
        opts.vocab_size,
        opts.sender_embedding,
        opts.sender_hidden,
        cell=opts.sender_cell,
        max_len=opts.max_len,
        temperature=opts.temperature,
    )

    receiver = core.RnnReceiverGS(
        receiver,
        opts.vocab_size,
        opts.receiver_embedding,
        opts.receiver_hidden,
        cell=opts.receiver_cell,
    )

    game = core.SenderReceiverRnnGS(sender, receiver, loss)

    # Sender and receiver get separate learning rates.
    optimizer = torch.optim.Adam([
        {
            "params": game.sender.parameters(),
            "lr": opts.sender_lr
        },
        {
            "params": game.receiver.parameters(),
            "lr": opts.receiver_lr
        },
    ])
    callbacks = [
        core.ConsoleLogger(as_json=True),
        core.TemperatureUpdater(agent=sender, decay=0.9, minimum=0.1),
    ]
    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_data,
        validation_data=validation_data,
        callbacks=callbacks,
    )
    trainer.train(n_epochs=opts.n_epochs)

    if not opts.evaluate:
        return

    (
        sender_inputs,
        messages,
        receiver_inputs,
        receiver_outputs,
        labels,
    ) = dump_sender_receiver(game,
                             test_data,
                             is_gs,
                             variable_length=True,
                             device=device)

    receiver_outputs = move_to(receiver_outputs, device)
    labels = move_to(labels, device)

    receiver_outputs = torch.stack(receiver_outputs)
    labels = torch.stack(labels)

    tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
    accuracy = torch.mean(tensor_accuracy.float()).item()

    # Distinct sender targets, each rendered as "d1-d2-...-dn".
    unique_targets = {
        "-".join(str(int(dim.item())) for dim in elem)
        for elem in sender_inputs
    }

    print(f"| Accuracy on test set: {accuracy}")

    compute_mi_input_msgs(sender_inputs, messages)

    print(f"entropy sender inputs {entropy(sender_inputs)}")
    print(f"mi sender inputs msgs {mutual_info(sender_inputs, messages)}")

    if not opts.dump_msg_folder:
        return

    opts.dump_msg_folder.mkdir(exist_ok=True)
    msg_counts = {}

    output_msg = (
        f"messages_{opts.perceptual_dimensions}_vocab_{opts.vocab_size}"
        f"_maxlen_{opts.max_len}_bsize_{opts.batch_size}"
        f"_n_distractors_{opts.n_distractors}_train_size_{opts.train_samples}"
        f"_valid_size_{opts.validation_samples}_test_size_{opts.test_samples}"
        f"_slr_{opts.sender_lr}_rlr_{opts.receiver_lr}_shidden_{opts.sender_hidden}"
        f"_rhidden_{opts.receiver_hidden}_semb_{opts.sender_embedding}"
        f"_remb_{opts.receiver_embedding}_mode_{opts.mode}"
        f"_scell_{opts.sender_cell}_rcell_{opts.receiver_cell}.msg")

    output_file = opts.dump_msg_folder / output_msg
    with open(output_file, "w") as f:
        f.write(f"{opts}\n")
        for (
                sender_input,
                message,
                receiver_input,
                receiver_output,
                label,
        ) in zip(sender_inputs, messages, receiver_inputs,
                 receiver_outputs, labels):
            sender_str = ",".join(map(str, sender_input.tolist()))
            message_str = ",".join(map(str, message.tolist()))
            receiver_str = "; ".join(
                ",".join(map(str, elem)) for elem in receiver_input.tolist()
            )
            if is_gs:
                receiver_output = receiver_output.argmax()
            f.write(
                f"{sender_str} -> {receiver_str} -> {message_str} -> {receiver_output} (label={label.item()})\n"
            )

            msg_counts[message_str] = msg_counts.get(message_str, 0) + 1

        # Most frequent messages first.
        sorted_msgs = sorted(msg_counts.items(),
                             key=operator.itemgetter(1),
                             reverse=True)
        f.write(
            f"\nUnique target vectors seen by sender: {len(unique_targets)}\n"
        )
        f.write(
            f"Unique messages produced by sender: {len(msg_counts)}\n"
        )
        f.write(f"Messagses: 'msg' : msg_count: {str(sorted_msgs)}\n")
        f.write(f"\nAccuracy: {accuracy}")
コード例 #6
0
# ================ First play the recon game, get the sender ==================
# NOTE(review): the header above mentions the recon game, but the object
# trained here is a SymbolicReferGame; its sender is the one reused below.
# Confirm which wording is intended.
ref_game = SymbolicReferGame()
ref_game.train(10000)  # the argument is the number of training epochs.

# Put the trained game into eval mode before its sender is reused.
ref_game.game.eval()
recon_game = SymbolicReconGame(
    training_log='~/GitWS/GameBias/log/recon_train_temp.txt')
# Only the receiver of the recon game is handed to the optimizer.
optimizer = core.build_optimizer(recon_game.game.receiver.parameters())
# Per-epoch mean losses.
train_loss = []
test_loss = []

for i in range(5000):
    # Despite its name, acc_list collects per-batch mean *loss* values.
    acc_list = []
    for batch_idx, (target, label) in enumerate(recon_game.train_loader):
        optimizer.zero_grad()
        target = move_to(target, recon_game.trainer.device)
        label = move_to(label, recon_game.trainer.device)

        # The message comes from the ref game's sender and is detached, so
        # the backward pass below does not reach that sender.
        msg, _ = ref_game.sender(target)
        rec_out = recon_game.receiver(msg.detach())
        loss, _ = recon_game.loss(target, msg, msg, rec_out, label)
        acc_list.append(loss.mean().item())
        loss.sum().backward()
        optimizer.step()
    print('train loss:', np.mean(acc_list))
    train_loss.append(np.mean(acc_list))

    # Start a fresh list for the test pass of this epoch.
    acc_list = []
    for batch_idx, (target, label) in enumerate(recon_game.test_loader):
        recon_game.game.eval()
        target = move_to(target, recon_game.trainer.device)
コード例 #7
0
ファイル: util.py プロジェクト: wedddy0707/EGG
def dump_sender_receiver(
    game: torch.nn.Module,
    dataset: "torch.utils.data.DataLoader",
    gs: bool,
    variable_length: bool,
    device: Optional[torch.device] = None,
):
    """
    A tool to dump the interaction between Sender and Receiver.

    The game is switched to eval mode for the duration of the dump and its
    previous train/eval mode is restored afterwards, even if the dump fails
    with an exception.

    :param game: A Game instance
    :param dataset: Dataset of inputs to be used when analyzing the communication
    :param gs: whether Gumbel-Softmax relaxation was used during training
    :param variable_length: whether variable-length communication is used
    :param device: device (e.g. 'cuda') to be used; falls back to
        ``common_opts.device`` when omitted
    :return: a tuple of five lists
        ``(sender_inputs, messages, receiver_inputs, receiver_outputs, labels)``
        with one entry per sample. ``receiver_inputs`` stays empty when the
        batches carry no receiver input, and ``labels`` only receives entries
        for batches whose label element is not ``None``.
    """
    train_state = game.training  # persist so we restore it back
    game.eval()

    sender_inputs, messages, receiver_inputs, receiver_outputs = [], [], [], []
    labels = []

    try:
        device = device if device is not None else common_opts.device

        with torch.no_grad():
            for batch in dataset:
                # by agreement, each batch is (sender_input, labels) plus optional (receiver_input)
                sender_input = move_to(batch[0], device)
                receiver_input = None if len(batch) == 2 else move_to(
                    batch[2], device)

                message = game.sender(sender_input)

                # Under GS, the only output is a message; under Reinforce, two additional tensors are returned.
                # We don't need them.
                if not gs:
                    message = message[0]

                output = game.receiver(message, receiver_input)
                if not gs:
                    output = output[0]

                if batch[1] is not None:
                    labels.extend(batch[1])

                if isinstance(sender_input, (list, tuple)):
                    # Several parallel inputs: regroup them per sample.
                    sender_inputs.extend(zip(*sender_input))
                else:
                    sender_inputs.extend(sender_input)

                if receiver_input is not None:
                    receiver_inputs.extend(receiver_input)

                if gs:
                    # actual symbols instead of one-hot encoded
                    message = message.argmax(dim=-1)

                if not variable_length:
                    messages.extend(message)
                    receiver_outputs.extend(output)
                    continue

                # A trickier part is to handle EOS in the messages. It also might happen that not every message has EOS.
                # We cut messages at EOS if it is present or return the entire message otherwise. Note, EOS id is always
                # set to 0.
                for i in range(message.size(0)):
                    eos_positions = (message[i, :] == 0).nonzero()
                    message_end = (eos_positions[0].item()
                                   if eos_positions.size(0) > 0 else -1)
                    assert message_end == -1 or message[i, message_end] == 0
                    if message_end < 0:
                        messages.append(message[i, :])
                    else:
                        messages.append(message[i, :message_end + 1])

                    if gs:
                        # message_end == -1 (no EOS found) indexes the last
                        # position along dimension 1 of the output.
                        receiver_outputs.append(output[i, message_end, ...])
                    else:
                        receiver_outputs.append(output[i, ...])
    finally:
        # Restore the caller's mode even when the dump raised.
        game.train(mode=train_state)

    return sender_inputs, messages, receiver_inputs, receiver_outputs, labels
コード例 #8
0
ファイル: train.py プロジェクト: DebasishMaji/EGG
        raise NotImplementedError(f'Unknown training mode, {opts.mode}')

    optimizer = torch.optim.Adam([
        {'params': game.sender.parameters(), 'lr': opts.sender_lr},
        {'params': game.receiver.parameters(), 'lr': opts.receiver_lr}
    ])

    trainer = core.Trainer(game=game, optimizer=optimizer,
                           train_data=train_data, validation_data=validation_data, epoch_callback=callback, as_json=opts.output_json)
    trainer.train(n_epochs=opts.n_epochs)

    if opts.evaluate:
        is_gs = opts.mode == 'gs'
        sender_inputs, messages, receiver_inputs, receiver_outputs, labels = core.dump_sender_receiver(game, test_data, is_gs, variable_length=True, device=device)

        receiver_outputs = move_to(receiver_outputs, device)
        labels = move_to(labels, device)

        receiver_outputs = torch.stack(receiver_outputs)
        labels = torch.stack(labels)

        tensor_accuracy = receiver_outputs.argmax(dim=1) == labels
        accuracy = torch.mean(tensor_accuracy.float()).item()

        unique_dict = {}

        for elem in sender_inputs:
            target = ""
            for dim in elem:
                target += f'{str(int(dim.item()))}-'
            target = target[:-1]