Example #1
    def test_trainer(self):
        batch_size = 3
        length = 5
        descriptors = torch.FloatTensor(self.config.nclasses,
                                        self.config.descriptor_dim).normal_()
        sender = Sender(self.config)
        sender.eval()
        receiver = Receiver(self.config)
        receiver.eval()
        exchange_model = ExchangeModel(self.config)
        baseline_sender = Baseline(self.config, 'sender')
        baseline_receiver = Baseline(self.config, 'receiver')
        exchange = Exchange(exchange_model, sender, receiver, baseline_sender,
                            baseline_receiver, descriptors)
        trainer = Trainer(exchange)

        image = torch.FloatTensor(batch_size, self.config.image_in).normal_()
        target_dist = F.softmax(torch.FloatTensor(
            batch_size, self.config.nclasses).normal_(),
                                dim=1)
        target = target_dist.argmax(dim=1)
        trainer_loss = trainer.run_step(image, target)

        self.assertEqual(trainer_loss.sender_message_loss.numel(), 1)
        self.assertEqual(trainer_loss.receiver_message_loss.numel(), 1)
        self.assertEqual(trainer_loss.stop_loss.numel(), 1)
        self.assertEqual(trainer_loss.baseline_loss_sender.numel(), 1)
        self.assertEqual(trainer_loss.baseline_loss_receiver.numel(), 1)
        self.assertEqual(trainer_loss.xent_loss.numel(), 1)
Example #2
    def test_reweight_descriptors(self):
        receiver = Receiver(self.config)

        batch_size = 2
        scores = F.softmax(torch.FloatTensor(batch_size,
                                             self.config.nclasses).normal_(),
                           dim=1)
        descriptors = torch.FloatTensor(self.config.nclasses,
                                        self.config.descriptor_dim).normal_()
        influence = receiver.reweight_descriptors(scores, descriptors)

        self.assertEqual(influence.size(),
                         (batch_size, self.config.descriptor_dim))
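The test only pins down the output shape. For orientation, here is a minimal sketch of what reweight_descriptors could compute, assuming it is a score-weighted sum of the class descriptors; the composition below is an assumption, not the actual Receiver implementation:

import torch

def reweight_descriptors_sketch(scores, descriptors):
    # scores:      (batch_size, nclasses), rows sum to 1 after softmax
    # descriptors: (nclasses, descriptor_dim)
    # returns:     (batch_size, descriptor_dim), the shape asserted above
    return scores @ descriptors  # probability-weighted sum over class descriptors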
Example #3
    def test_forward_eval(self):
        batch_size = 3
        receiver = Receiver(self.config)
        receiver.eval()

        msg = torch.FloatTensor(batch_size, self.config.message_in).normal_()
        state = torch.FloatTensor(batch_size,
                                  self.config.receiver_hidden_dim).fill_(0)
        descriptors = torch.FloatTensor(self.config.nclasses,
                                        self.config.descriptor_dim).normal_()
        (stop_bit, stop_dist), (message, message_dist), y, _ = receiver(
            msg, state, None, descriptors)

        check_bernoulli_dist(self, stop_dist)
        check_bernoulli_out(self, stop_bit)
        check_bernoulli_dist(self, message_dist)
        check_bernoulli_out(self, message)
        self.assertEqual(y.size(), (batch_size, self.config.nclasses))
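check_bernoulli_dist and check_bernoulli_out are shared test helpers not shown in this snippet. A plausible sketch of what they assert, assuming the distribution is a tensor of probabilities and the output a tensor of samples:

def check_bernoulli_dist(test_case, dist):
    # Bernoulli parameters must be valid probabilities
    test_case.assertTrue(((dist >= 0) & (dist <= 1)).all().item())

def check_bernoulli_out(test_case, out):
    # samples drawn from a Bernoulli distribution are binary
    test_case.assertTrue(((out == 0) | (out == 1)).all().item())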
Example #4
    def test_build_pairs(self):
        receiver = Receiver(self.config)

        state = torch.FloatTensor(3, 2)
        state[0] = 0
        state[1] = 1
        state[2] = 2
        descs = torch.FloatTensor(2, 4)
        descs[0] = 5
        descs[1] = 6

        pairs = receiver.build_state_descriptor_pairs(state, descs)

        self.assertEqual(pairs.size(), (6, 6))
        self.assertTrue(pairs[[0, 1], :2].eq(0).all())
        self.assertTrue(pairs[[2, 3], :2].eq(1).all())
        self.assertTrue(pairs[[4, 5], :2].eq(2).all())
        self.assertTrue(pairs[[0, 2, 4], 2:].eq(5).all())
        self.assertTrue(pairs[[1, 3, 5], 2:].eq(6).all())
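The assertions fix the layout exactly: every state row is paired with every descriptor row, state first, with the descriptor index varying fastest. A sketch that reproduces this layout (an assumption; the actual Receiver may construct it differently):

import torch

def build_state_descriptor_pairs_sketch(state, descs):
    # state: (S, state_dim), descs: (D, desc_dim) -> (S * D, state_dim + desc_dim)
    S, D = state.size(0), descs.size(0)
    state_rep = state.repeat_interleave(D, dim=0)  # s0, s0, s1, s1, s2, s2
    descs_rep = descs.repeat(S, 1)                 # d0, d1, d0, d1, d0, d1
    return torch.cat([state_rep, descs_rep], dim=1)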
Example #5
def train(proposer, receiver, run_n, threshold, n_episodes, df):
    # The proposer and receiver are built in main() and passed in here,
    # so that eval() afterwards runs on the trained agents.

    for ep in range(1, n_episodes + 1):

        prop_idx = proposer.epsilon_greedy_proposal()
        prop = proposer.act(prop_idx)
        resp = receiver.threshold_recv(prop)

        rew = get_reward(resp, prop)
        proposer.TD_prop(prop_idx, rew)

        if ep % ep_step == 0:
            # log one row every ep_step episodes
            row_idx = int(run_n * n_episodes / ep_step + ep / ep_step)
            df.loc[row_idx] = [run_n, ep, threshold, epsilon, rew] + list(proposer.q_prop)
def main(n_episodes):

    df = pd.DataFrame(columns=['run', 'episode', 'threshold', 'epsilon',
                               'rew_proposer', 'rew_receiver', 'train?']
                      + ["prop_" + str(i / n_actions) for i in range(n_actions)]
                      + ["recv_" + str(i / n_actions) for i in range(n_actions)])

    for run, threshold in enumerate(thresholds):

        proposer = Sender(idx=0, lr=lr, n_actions=n_actions, epsilon=epsilon)
        receiver = Receiver(idx=1,
                            lr=lr,
                            threshold=threshold,
                            n_states=n_actions,
                            epsilon=epsilon)

        print("Run number=", run, " with threshold=", threshold)
        print("Train")
        train(proposer, receiver, run, threshold, n_episodes, df)
        print("Eval")
        eval(proposer, receiver, run_n=1, n_episodes_eval=0, df=df)

    return df
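The loop assumes a Sender with bandit-style Q-values over discrete proposals (q_prop, epsilon_greedy_proposal, act, TD_prop) defined elsewhere. Below is a minimal sketch of those pieces under the usual epsilon-greedy assumptions; the class name, attributes, and update rule are assumptions, not the actual implementation:

import numpy as np

class SenderSketch:
    def __init__(self, n_actions, lr, epsilon):
        self.q_prop = np.zeros(n_actions)  # value estimate per proposal
        self.n_actions, self.lr, self.epsilon = n_actions, lr, epsilon

    def epsilon_greedy_proposal(self):
        # explore with probability epsilon, otherwise pick the best estimate
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.n_actions)
        return int(np.argmax(self.q_prop))

    def act(self, idx):
        # map the discrete index to a proposal in [0, 1)
        return idx / self.n_actions

    def TD_prop(self, idx, rew):
        # incremental update of the chosen proposal's value toward the reward
        self.q_prop[idx] += self.lr * (rew - self.q_prop[idx])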
Example #7
def main():
    train_loader, val_loader, input_size = get_data(100)

    show_batch(train_loader)

    # define agents
    hidden_size = 32
    n_actions = 2
    vocab_size = 11

    # train hyperparams
    num_episodes = 100
    lr = 0.1

    send = Sender(in_size=input_size, hidden_size=hidden_size,
                  vocab_len=vocab_size, lr=lr)

    recv = Receiver(in_size=input_size, vocabulary_size_sender=vocab_size,
                    hidden_size=hidden_size, n_actions=n_actions, lr=lr)


    def batch(send, recv, images_batch, labels_batch, send_opt=None, recv_opt=None):

        imgsa_s, imgsb_s, imgsa_r, imgsb_r, targets, _ = get_images(images_batch, labels_batch)

        probs_s, message, logprobs_s, entropy_s = send.model(imgsa_s, imgsb_s)
        probs_r, actions, logprobs_r, entropy_r = recv.model(imgsa_r, imgsb_r, message.detach())

        error = reward(actions, targets)  # alternative: torch.abs(actions - targets); the minus sign is already handled in the update, it seems
        acc = accuracy(actions, targets)  # alternative: torch.mean(error.detach().double())

        send_loss = send.loss(error, logprobs_s, entropy_s)
        recv_loss = recv.loss(error, logprobs_r, entropy_r)

        if send_opt is not None:

            # SENDER LOSS
            send_opt.zero_grad()
            send_loss.backward()
            send_opt.step()

        if recv_opt is not None:

            # RECEIVER LOSS
            recv_opt.zero_grad()
            recv_loss.backward()
            recv_opt.step()

        return error, send_loss, recv_loss, len(imgsa_s), acc

    # UPLOAD MODELS

    #send.model.load_state_dict(torch.load('sender_model_mnist.pth'))
    #recv.model.load_state_dict(torch.load('receiver_model_mnist.pth'))

    send_opt = Adam(send.model.parameters(), lr=lr)
    recv_opt = Adam(recv.model.parameters(), lr=lr)
    print("lr=", lr)

    #TRAIN LOOP

    train_send_losses, train_recv_losses, val_send_losses, val_recv_losses, val_accuracy = [], [], [], [], []

    for ep in range(num_episodes):
        print("episode=", ep)

        # TRAIN STEP
        print("train")
        for imgs, labs in train_loader:

            train_error, train_send_loss, train_recv_loss, _, train_acc = batch(send, recv, imgs, labs, send_opt, recv_opt)

        print("evaluation")
        # EVALUATION STEP
        with torch.no_grad():

            results = [ batch(send, recv, imgs, labs) for imgs, labs in val_loader ]
            
        val_error, val_send_loss, val_recv_loss, nums, val_acc = zip(*results)

        total = np.sum(nums)
        send_train_avg_loss = np.sum(np.multiply(train_send_loss.detach().numpy(), nums))/total
        recv_train_avg_loss = np.sum(np.multiply(train_recv_loss.detach().numpy(), nums))/total
        train_send_losses.append(send_train_avg_loss)
        train_recv_losses.append(recv_train_avg_loss)

        send_val_avg_loss = np.sum(np.multiply(val_send_loss, nums))/total
        recv_val_avg_loss = np.sum(np.multiply(val_recv_loss, nums))/total
        val_send_losses.append(send_val_avg_loss)
        val_recv_losses.append(recv_val_avg_loss)
            
        val_avg_accuracy = np.sum(np.multiply(val_acc, nums))/total
        val_accuracy.append(val_avg_accuracy)

        print("sender train loss", send_train_avg_loss)
        print("receiver train loss", recv_train_avg_loss)
        print("sender val loss", send_val_avg_loss)
        print("receiver val loss", recv_val_avg_loss)
        print("accuracy", val_avg_accuracy)
        print("\n")

    torch.save(send.model.state_dict(), 'sender_model_cifar_01.pth')
    torch.save(recv.model.state_dict(), 'receiver_model_cifar_01.pth')
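send.loss and recv.loss are called with (reward, log-probabilities, entropy), which suggests a REINFORCE-style objective. A sketch of that common form, offered only as an assumption about what these losses typically look like (entropy_coeff is a hypothetical hyperparameter):

import torch

def reinforce_loss_sketch(reward, logprobs, entropy, entropy_coeff=0.01):
    # maximize expected reward: scale log-probs by the (detached) reward,
    # and add an entropy bonus to keep the policies from collapsing early
    return -(reward.detach() * logprobs).mean() - entropy_coeff * entropy.mean()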
Example #8
def run():
    # Get Description Vectors

    ## Training
    descr_train, word_dict_train, dict_size_train, label_id_to_idx_train, idx_to_label_train = read_data(
        FLAGS.descr_train)

    def map_labels_train(x):
        return label_id_to_idx_train.get(x)

    word_dict_train = embed(word_dict_train, FLAGS.word_embedding_path)
    descr_train = cbow(descr_train, word_dict_train)
    desc_train = torch.cat(
        [descr_train[i]["cbow"].view(1, -1) for i in descr_train.keys()], 0)
    desc_train_set = torch.cat([
        descr_train[i]["set"].view(-1, FLAGS.word_embedding_dim)
        for i in descr_train.keys()
    ], 0)
    desc_train_set_lens = [
        len(descr_train[i]["desc"]) for i in descr_train.keys()
    ]

    ## Development
    descr_dev, word_dict_dev, dict_size_dev, label_id_to_idx_dev, idx_to_label_dev = read_data(
        FLAGS.descr_dev)

    def map_labels_dev(x):
        return label_id_to_idx_dev.get(x)

    word_dict_dev = embed(word_dict_dev, FLAGS.word_embedding_path)
    descr_dev = cbow(descr_dev, word_dict_dev)
    desc_dev = torch.cat(
        [descr_dev[i]["cbow"].view(1, -1) for i in descr_dev.keys()], 0)
    desc_dev_set = torch.cat([
        descr_dev[i]["set"].view(-1, FLAGS.word_embedding_dim)
        for i in descr_dev.keys()
    ], 0)
    desc_dev_set_lens = [len(descr_dev[i]["desc"]) for i in descr_dev.keys()]

    desc_dev_dict = dict(desc=desc_dev,
                         desc_set=desc_dev_set,
                         desc_set_lens=desc_dev_set_lens)

    # Initialize Models
    config = AgentConfig()
    exchange_model = ExchangeModel(config)
    sender = Sender(config)
    receiver = Receiver(config)
    baseline_sender = Baseline(config, 'sender')
    baseline_receiver = Baseline(config, 'receiver')
    exchange = Exchange(exchange_model, sender, receiver, baseline_sender,
                        baseline_receiver, desc_train)
    trainer = Trainer(exchange)

    # Initialize Optimizer
    optimizer = optim.RMSprop(exchange.parameters(), lr=FLAGS.learning_rate)

    # Static Variables
    img_feat = "avgpool_512"
    topk = 5

    accs = []

    # Run Epochs
    for epoch in range(FLAGS.max_epoch):
        source = "directory"
        path = FLAGS.train_data
        loader_config = DirectoryLoaderConfig.build_with("resnet18")
        loader_config.map_labels = map_labels_train
        loader_config.batch_size = FLAGS.batch_size
        loader_config.shuffle = True
        loader_config.cuda = FLAGS.cuda

        dataloader = DataLoader.build_with(path, source,
                                           loader_config).iterator()

        for i_batch, batch in enumerate(dataloader):

            data = batch[img_feat]
            target = batch["target"]
            trainer_loss = trainer.run_step(data, target)
            loss = trainer.calculate_loss(trainer_loss)

            # Update Parameters
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(exchange.parameters(), max_norm=1.)
            optimizer.step()

            y = trainer_loss.y
            topk_indices = y.sort()[1][:, -topk:]
            # use the actual batch size so the last, possibly smaller, batch is handled
            target_broadcast = target.view(-1, 1).expand(-1, topk)
            accuracy = (topk_indices == target_broadcast).sum().float() / float(target.size(0))
            accs.append(accuracy)
            mean_acc = sum(accs) / len(accs)

            print("Epoch = {}; Batch = {}; Accuracy = {}".format(
                epoch, i_batch, mean_acc))

            if len(accs) > 5:
                accs.pop(0)
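trainer.calculate_loss aggregates the per-component losses enumerated in Example #1. A plausible composition, assuming equal weighting (the real Trainer may weight or mask the terms):

def calculate_loss_sketch(trainer_loss):
    # sum the scalar components asserted in Example #1
    return (trainer_loss.xent_loss
            + trainer_loss.sender_message_loss
            + trainer_loss.receiver_message_loss
            + trainer_loss.stop_loss
            + trainer_loss.baseline_loss_sender
            + trainer_loss.baseline_loss_receiver)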