Example #1
0
class Hidden:
    """HiDDeN watermarking model: an encoder-decoder that hides/recovers a
    message in an image, trained jointly with an adversarial discriminator
    that tries to tell encoded images from cover images."""

    def __init__(self, configuration: HiDDenConfiguration,
                 device: torch.device, noiser: Noiser, tb_logger):
        """
        :param configuration: Configuration for the net, such as the size of the input image, number of channels in the intermediate layers, etc.
        :param device: torch.device object, CPU or GPU
        :param noiser: Object representing stacked noise layers.
        :param tb_logger: Optional TensorboardX logger object, if specified -- enables Tensorboard logging
        """
        super(Hidden, self).__init__()

        self.encoder_decoder = EncoderDecoder(configuration, noiser).to(device)
        self.optimizer_enc_dec = torch.optim.Adam(
            self.encoder_decoder.parameters())

        self.discriminator = Discriminator(configuration).to(device)
        self.optimizer_discrim = torch.optim.Adam(
            self.discriminator.parameters())

        # Optional perceptual (VGG-feature) loss for the encoder; when
        # disabled the encoder loss falls back to plain pixel-space MSE.
        if configuration.use_vgg:
            self.vgg_loss = VGGLoss(3, 1, False)
            self.vgg_loss.to(device)
        else:
            self.vgg_loss = None

        self.config = configuration
        self.device = device

        self.bce_with_logits_loss = nn.BCEWithLogitsLoss()
        self.mse_loss = nn.MSELoss()

        # Labels used for training the discriminator/adversarial loss.
        self.cover_label = 1
        self.encoded_label = 0

        self.tb_logger = tb_logger
        if tb_logger is not None:
            # Register gradient hooks on the final layer of each sub-network
            # so their output gradients show up in Tensorboard.
            # (The original imported TensorBoardLogger here but never used it.)
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            encoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/encoder_out'))
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            decoder_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/decoder_out'))
            discrim_final = self.discriminator._modules['linear']
            discrim_final.weight.register_hook(
                tb_logger.grad_hook_by_name('grads/discrim_out'))

    def _label_tensor(self, batch_size: int, label: int) -> torch.Tensor:
        """Return a (batch_size, 1) tensor filled with `label` on self.device.

        dtype is forced to float: with an int fill value newer PyTorch
        versions create a Long tensor, which BCEWithLogitsLoss rejects
        as a target.
        """
        return torch.full((batch_size, 1), label, dtype=torch.float,
                          device=self.device)

    def _encoder_loss(self, encoded_images, images):
        """Encoder distortion loss: VGG-feature MSE if enabled, else pixel MSE."""
        if self.vgg_loss is None:
            return self.mse_loss(encoded_images, images)
        vgg_on_cov = self.vgg_loss(images)
        vgg_on_enc = self.vgg_loss(encoded_images)
        return self.mse_loss(vgg_on_cov, vgg_on_enc)

    @staticmethod
    def _bitwise_error(decoded_messages, messages, batch_size):
        """Average per-bit error between rounded decoded messages and originals."""
        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        return np.sum(
            np.abs(decoded_rounded - messages.detach().cpu().numpy())) / (
                batch_size * messages.shape[1])

    def train_on_batch(self, batch: list):
        """
        Trains the network on a single batch consisting of images and messages
        :param batch: batch of training data, in the form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        images, messages = batch
        batch_size = images.shape[0]
        with torch.enable_grad():
            # ---------------- Train the discriminator -----------------------
            self.optimizer_discrim.zero_grad()
            # train on cover
            d_target_label_cover = self._label_tensor(batch_size,
                                                      self.cover_label)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake; detach() so discriminator gradients do not
            # propagate into the encoder-decoder
            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = self._label_tensor(batch_size,
                                                        self.encoded_label)
            d_on_encoded = self.discriminator(encoded_images.detach())
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encoded.backward()
            self.optimizer_discrim.step()

            # -------------- Train the generator (encoder-decoder) -----------
            self.optimizer_enc_dec.zero_grad()
            # target label for encoded images should be 'cover', because we
            # want to fool the discriminator
            g_target_label_encoded = self._label_tensor(batch_size,
                                                        self.cover_label)
            d_on_encoded_for_enc = self.discriminator(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)

            g_loss_enc = self._encoder_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            # Weighted sum of adversarial, image-distortion and
            # message-recovery losses.
            g_loss = self.config.adversarial_loss * g_loss_adv \
                     + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.optimizer_enc_dec.step()

        bitwise_avg_err = self._bitwise_error(decoded_messages, messages,
                                              batch_size)

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        """
        Runs validation on a single batch of data consisting of images and messages
        :param batch: batch of validation data, in form [images, messages]
        :return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
        """
        # if TensorboardX logging is enabled, save some of the tensors.
        if self.tb_logger is not None:
            encoder_final = self.encoder_decoder.encoder._modules[
                'final_layer']
            self.tb_logger.add_tensor('weights/encoder_out',
                                      encoder_final.weight)
            decoder_final = self.encoder_decoder.decoder._modules['linear']
            self.tb_logger.add_tensor('weights/decoder_out',
                                      decoder_final.weight)
            discrim_final = self.discriminator._modules['linear']
            self.tb_logger.add_tensor('weights/discrim_out',
                                      discrim_final.weight)

        images, messages = batch
        batch_size = images.shape[0]

        with torch.no_grad():
            d_target_label_cover = self._label_tensor(batch_size,
                                                      self.cover_label)
            # (The original ran the discriminator on the cover images twice;
            # one forward pass suffices.)
            d_on_cover = self.discriminator(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, noised_images, decoded_messages = self.encoder_decoder(
                images, messages)
            d_target_label_encoded = self._label_tensor(batch_size,
                                                        self.encoded_label)
            d_on_encoded = self.discriminator(encoded_images)
            d_loss_on_encoded = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            # Under no_grad a second discriminator pass on encoded_images
            # would be identical -- reuse d_on_encoded for the generator loss.
            g_target_label_encoded = self._label_tensor(batch_size,
                                                        self.cover_label)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded,
                                                   g_target_label_encoded)

            g_loss_enc = self._encoder_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)
            g_loss = self.config.adversarial_loss * g_loss_adv \
                     + self.config.encoder_loss * g_loss_enc \
                     + self.config.decoder_loss * g_loss_dec

        bitwise_avg_err = self._bitwise_error(decoded_messages, messages,
                                              batch_size)

        losses = {
            'loss           ': g_loss.item(),
            'encoder_mse    ': g_loss_enc.item(),
            'dec_mse        ': g_loss_dec.item(),
            'bitwise-error  ': bitwise_avg_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encoded.item()
        }
        return losses, (encoded_images, noised_images, decoded_messages)

    def to_string(self):
        """Human-readable description of the encoder-decoder and discriminator."""
        return '{}\n{}'.format(str(self.encoder_decoder),
                               str(self.discriminator))

    # Backward-compatible alias for the original misspelled method name.
    to_stirng = to_string
Example #2
0
class HiDDen(object):
    """HiDDen watermarking model: encoder-decoder plus adversarial
    discriminator (noise-free variant)."""

    def __init__(self, config: HiDDenConfiguration, device: torch.device):
        """
        :param config: network and loss-weight configuration
        :param device: torch.device (CPU or GPU) hosting all modules/tensors
        """
        self.enc_dec = EncoderDecoder(config).to(device)
        self.discr = Discriminator(config).to(device)
        self.opt_enc_dec = torch.optim.Adam(self.enc_dec.parameters())
        self.opt_discr = torch.optim.Adam(self.discr.parameters())

        self.config = config
        self.device = device
        self.bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
        self.mse_loss = nn.MSELoss().to(device)

        # Target labels for the discriminator / adversarial loss.
        self.cover_label = 1
        self.encod_label = 0

    def _labels(self, batch_size: int):
        """Return (cover, encoded, generator) target tensors of shape (batch_size, 1).

        dtype is forced to float: with an int fill value newer PyTorch builds
        a Long tensor, which BCEWithLogitsLoss rejects as a target.
        """
        d_target_label_cover = torch.full((batch_size, 1),
                                          self.cover_label,
                                          dtype=torch.float,
                                          device=self.device)
        d_target_label_encoded = torch.full((batch_size, 1),
                                            self.encod_label,
                                            dtype=torch.float,
                                            device=self.device)
        # The generator wants encoded images classified as 'cover'.
        g_target_label_encoded = torch.full((batch_size, 1),
                                            self.cover_label,
                                            dtype=torch.float,
                                            device=self.device)
        return d_target_label_cover, d_target_label_encoded, g_target_label_encoded

    def train_on_batch(self, batch: list):
        '''
        Trains the network on a single batch consisting of images and messages.
        :param batch: [images, messages]
        :return: (dict of error metrics, (encoded_images, decoded_messages))
        '''
        images, messages = batch
        batch_size = images.shape[0]
        self.enc_dec.train()
        self.discr.train()

        with torch.enable_grad():
            # ---------- Train the discriminator----------
            self.opt_discr.zero_grad()

            d_target_label_cover, d_target_label_encoded, g_target_label_encoded = \
                self._labels(batch_size)

            # train on cover
            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)
            d_loss_on_cover.backward()

            # train on fake; detach() keeps discriminator gradients out of
            # the encoder-decoder
            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images.detach())
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)
            d_loss_on_encod.backward()
            self.opt_discr.step()

            #---------- Train the generator----------
            self.opt_enc_dec.zero_grad()

            d_on_encoded_for_enc = self.discr(encoded_images)
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded_for_enc,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec
            g_loss.backward()
            self.opt_enc_dec.step()

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy())) \
                      / (batch_size * messages.shape[1])

        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def validate_on_batch(self, batch: list):
        '''Run validation on a batch consisting of [images, messages].'''
        images, messages = batch
        batch_size = images.shape[0]

        self.enc_dec.eval()
        self.discr.eval()

        with torch.no_grad():
            d_target_label_cover, d_target_label_encoded, g_target_label_encoded = \
                self._labels(batch_size)

            d_on_cover = self.discr(images)
            d_loss_on_cover = self.bce_with_logits_loss(
                d_on_cover, d_target_label_cover)

            encoded_images, decoded_messages = self.enc_dec(images, messages)
            d_on_encoded = self.discr(encoded_images)
            d_loss_on_encod = self.bce_with_logits_loss(
                d_on_encoded, d_target_label_encoded)

            # Under no_grad a second discriminator pass would be identical --
            # reuse d_on_encoded for the adversarial (generator) loss.
            g_loss_adv = self.bce_with_logits_loss(d_on_encoded,
                                                   g_target_label_encoded)
            g_loss_enc = self.mse_loss(encoded_images, images)
            g_loss_dec = self.mse_loss(decoded_messages, messages)

            g_loss = self.config.adversarial_loss * g_loss_adv \
                    + self.config.encoder_loss * g_loss_enc \
                    + self.config.decoder_loss * g_loss_dec

        decoded_rounded = decoded_messages.detach().cpu().numpy().round().clip(
            0, 1)
        bitwise_err = np.sum(np.abs(decoded_rounded - messages.detach().cpu().numpy()))\
                     / (batch_size * messages.shape[1])

        # Keys now match train_on_batch; the originals ('bitwise-err',
        # 'discr_enced_bce') were inconsistent with the training metrics,
        # which broke any aggregation keyed on metric names.
        losses = {
            'loss': g_loss.item(),
            'encoder_mse': g_loss_enc.item(),
            'decoder_mse': g_loss_dec.item(),
            'bitwise-error': bitwise_err,
            'adversarial_bce': g_loss_adv.item(),
            'discr_cover_bce': d_loss_on_cover.item(),
            'discr_encod_bce': d_loss_on_encod.item()
        }

        return losses, (encoded_images, decoded_messages)

    def to_string(self):
        """Human-readable description of both sub-networks."""
        return f'{str(self.enc_dec)}\n{str(self.discr)}'

    # Backward-compatible alias for the original misspelled method name.
    to_stirng = to_string
Example #3
0
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length, use_decay, data_path):
    """Train a seq2seq encoder-decoder, validating each epoch and saving the
    model whenever the dev-set BLEU score improves.

    :param encoder_decoder: the seq2seq model to train (modified in place)
    :param train_data_loader: yields (input_idxs, target_idxs, input_tokens, target_tokens)
    :param model_name: used to build the checkpoint path ./saved/<model_name>/
    :param val_data_loader: validation data for per-epoch evaluation
    :param keep_prob: dropout keep probability forwarded to the model
    :param teacher_forcing_schedule: per-epoch teacher-forcing ratios; its
        length determines the number of epochs
    :param lr: initial Adam learning rate
    :param max_length: maximum (padded) output sequence length
    :param use_decay: when True, halve the learning rate each epoch
    :param data_path: dataset root forwarded to get_bleu for dev scoring
    """
    import os

    global_step = 0
    loss_function = torch.nn.NLLLoss(ignore_index=0)  # index 0 is padding
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './saved/' + model_name + '/'
    # Ensure the checkpoint directory exists; torch.save does not create it.
    os.makedirs(model_path, exist_ok=True)

    # Halve the LR each epoch when decay is enabled; gamma=1.0 keeps it fixed.
    gamma = 0.5 if use_decay else 1.0
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

    best_bleu = 0.0

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)
        print('lr: ' + str(scheduler.get_lr()))

        for batch_idx, (input_idxs, target_idxs, input_tokens,
                        target_tokens) in enumerate(tqdm(train_data_loader)):
            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            # Sort by descending length (required for packed RNN sequences)
            # and trim the input to the longest sequence in the batch.
            input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
            target_variable = Variable(target_idxs[order, :])

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)
            batch_size = input_variable.shape[0]

            # NLLLoss expects (N, C): flatten all time steps into one batch.
            flattened_outputs = output_log_probs.view(batch_size * max_length,
                                                      -1)
            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))

            batch_loss.backward()
            optimizer.step()

            batch_outputs = trim_seqs(output_seqs)
            # corpus_bleu expects a list of reference lists per hypothesis.
            batch_targets = [[list(seq[seq > 0])]
                             for seq in list(to_np(target_variable))]
            batch_bleu_score = corpus_bleu(batch_targets, batch_outputs)

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag,
                                         value,
                                         global_step,
                                         bins='doane')
                    if value.grad is not None:
                        # Parameters that did not take part in this step
                        # have no gradient to log.
                        writer.add_histogram('grads/' + tag,
                                             to_np(value.grad),
                                             global_step,
                                             bins='doane')

            global_step += 1

        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score',
                          val_bleu_score,
                          global_step=global_step)

        # Log both embedding tables for projector visualisation.
        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings,
                             metadata=encoder_vocab,
                             global_step=0,
                             tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings,
                             metadata=decoder_vocab,
                             global_step=0,
                             tag='decoder_embeddings')

        # Checkpoint on the dev-set BLEU score (not the validation BLEU above).
        calc_bleu_score = get_bleu(encoder_decoder, data_path, None, 'dev')
        print('val loss: %.5f, val BLEU score: %.5f' %
              (val_loss, calc_bleu_score),
              flush=True)
        if calc_bleu_score > best_bleu:
            print("Best BLEU score! Saving model...")
            best_bleu = calc_bleu_score
            torch.save(
                encoder_decoder, "%s%s_%i_%.3f.pt" %
                (model_path, model_name, epoch, calc_bleu_score))

        print('-' * 100, flush=True)

        # Step the scheduler after the epoch so epoch 0 uses the initial LR.
        scheduler.step()
Example #4
0
def train(encoder_decoder: EncoderDecoder, train_data_loader: DataLoader,
          model_name, val_data_loader: DataLoader, keep_prob,
          teacher_forcing_schedule, lr, max_length, device,
          test_data_loader: DataLoader):
    """Train a seq2seq encoder-decoder, saving a checkpoint after every epoch
    and a final checkpoint at the end.

    :param encoder_decoder: the seq2seq model to train (modified in place)
    :param train_data_loader: yields (input_idxs, target_idxs, input_tokens, target_tokens)
    :param model_name: used to build the checkpoint path ./model/<model_name>/
    :param val_data_loader: validation data (currently unused in the loop)
    :param keep_prob: dropout keep probability forwarded to the model
    :param teacher_forcing_schedule: per-epoch teacher-forcing ratios; its
        length determines the number of epochs
    :param lr: Adam learning rate
    :param max_length: maximum (padded) output sequence length
    :param device: torch.device the batches are moved to
    :param test_data_loader: test data (currently unused in the loop)
    :return: the trained model
    """
    import os

    global_step = 0
    loss_function = torch.nn.NLLLoss(ignore_index=0)  # index 0 is padding
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './model/' + model_name + '/'
    # Ensure the checkpoint directory exists; torch.save does not create it.
    os.makedirs(model_path, exist_ok=True)
    trained_model = encoder_decoder

    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)
        correct_predictions = 0.0
        all_predictions = 0.0
        for batch_idx, (input_idxs, target_idxs, input_tokens,
                        target_tokens) in enumerate(tqdm(train_data_loader)):
            # Empty the CUDA cache at each batch to keep peak memory down.
            torch.cuda.empty_cache()
            # input_idxs and target_idxs have dim (batch_size x max_len);
            # they are NOT sorted by length.
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            # Sort by descending length (required for packed RNN sequences),
            # trim input to the longest sequence, and move to the device.
            input_variable = input_idxs[order, :][:, :max(lengths)].to(device)
            target_variable = target_idxs[order, :].to(device)

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(
                input_variable,
                list(sorted_lengths),
                targets=target_variable,
                keep_prob=keep_prob,
                teacher_forcing=teacher_forcing)

            batch_size = input_variable.shape[0]

            # NLLLoss expects (N, C): flatten all time steps into one batch.
            flattened_outputs = output_log_probs.view(batch_size * max_length,
                                                      -1)
            batch_loss = loss_function(flattened_outputs,
                                       target_variable.contiguous().view(-1))
            batch_outputs = trim_seqs(output_seqs)

            # corpus_bleu expects a list of reference lists per hypothesis.
            batch_targets = [[list(seq[seq > 0])]
                             for seq in list(to_np(target_variable))]

            # Sequence-level exact-match accuracy bookkeeping.
            for predicted, (target,) in zip(batch_outputs, batch_targets):
                if predicted == target:
                    correct_predictions += 1.0
                all_predictions += 1.0

            batch_loss.backward()
            optimizer.step()

            batch_bleu_score = corpus_bleu(
                batch_targets,
                batch_outputs,
                smoothing_function=SmoothingFunction().method1)

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score,
                                  global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag,
                                         value,
                                         global_step,
                                         bins='doane')
                    if value.grad is not None:
                        # Parameters that did not take part in this step
                        # have no gradient to log.
                        writer.add_histogram('grads/' + tag,
                                             to_np(value.grad),
                                             global_step,
                                             bins='doane')

            global_step += 1

        # Log both embedding tables for projector visualisation.
        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings,
                             metadata=encoder_vocab,
                             global_step=0,
                             tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings,
                             metadata=decoder_vocab,
                             global_step=0,
                             tag='decoder_embeddings')

        if all_predictions > 0:
            # Guard against an empty data loader (division by zero).
            print('training accuracy %.5f' %
                  (100.0 * (correct_predictions / all_predictions)))
        torch.save(encoder_decoder,
                   "%s%s_%i.pt" % (model_path, model_name, epoch))
        trained_model = encoder_decoder

        print('-' * 100, flush=True)

    torch.save(encoder_decoder, "%s%s_final.pt" % (model_path, model_name))
    return trained_model