Esempio n. 1
0
    def train(epoch):
        """Run one weakly supervised training epoch of the multimodal VAE.

        For each batch a uniform random number is drawn and compared with
        ``weak_perc``. Below the threshold the batch is used as a paired
        example: the joint, image-only and text-only passes are all run and
        their losses summed. Otherwise only the image-only and text-only
        passes contribute, each scored on its own modality.

        :param epoch: epoch number; only used in the progress messages.
        """
        # Re-seed on every call so that the paired/unpaired decision made for
        # a given batch position is identical from epoch to epoch.
        random.seed(42)
        np.random.seed(42)
        vae.train()

        joint_meter = AverageMeter()
        image_meter = AverageMeter()
        text_meter = AverageMeter()

        for step, (image, text) in enumerate(train_loader):
            if cuda:
                image = image.cuda()
                text = text.cuda()
            image = Variable(image)
            text = Variable(text)
            optimizer.zero_grad()

            show_pair = np.random.random() < weak_perc
            if show_pair:
                # Paired example: joint pass first, then each single modality.
                joint_image, joint_text, joint_mu, joint_logvar = vae(image, text)
                joint_loss = loss_function(
                    joint_mu, joint_logvar,
                    recon_image=joint_image, image=image,
                    recon_text=joint_text, text=text,
                    kl_lambda=kl_lambda, lambda_xy=1., lambda_yx=1.)

                img_image, img_text, img_mu, img_logvar = vae(image=image)
                image_loss = loss_function(
                    img_mu, img_logvar,
                    recon_image=img_image, image=image,
                    recon_text=img_text, text=text,
                    kl_lambda=kl_lambda, lambda_xy=1., lambda_yx=1.)

                txt_image, txt_text, txt_mu, txt_logvar = vae(text=text)
                text_loss = loss_function(
                    txt_mu, txt_logvar,
                    recon_image=txt_image, image=image,
                    recon_text=txt_text, text=text,
                    kl_lambda=kl_lambda, lambda_xy=0., lambda_yx=1.)

                total_loss = joint_loss + image_loss + text_loss
                # The joint meter only sees batches that were shown as pairs.
                joint_meter.update(joint_loss.data[0], len(image))
            else:
                # Unpaired: each modality is reconstructed from itself only.
                img_image, _, img_mu, img_logvar = vae(image=image)
                image_loss = loss_function(
                    img_mu, img_logvar,
                    recon_image=img_image, image=image,
                    kl_lambda=kl_lambda, lambda_xy=1.)

                _, txt_text, txt_mu, txt_logvar = vae(text=text)
                text_loss = loss_function(
                    txt_mu, txt_logvar,
                    recon_text=txt_text, text=text,
                    kl_lambda=kl_lambda, lambda_yx=1.)

                total_loss = image_loss + text_loss

            image_meter.update(image_loss.data[0], len(image))
            text_meter.update(text_loss.data[0], len(text))

            total_loss.backward()
            optimizer.step()

            if batch_should_log(step):
                progress = '[Weak {:.0f}%] Train Epoch: {} [{}/{} ({:.0f}%)]\tJoint Loss: {:.6f}\tImage Loss: {:.6f}\tText Loss: {:.6f}'.format(
                    100. * weak_perc, epoch, step * len(image), len(train_loader.dataset),
                    100. * step / len(train_loader), joint_meter.avg,
                    image_meter.avg, text_meter.avg)
                print(progress)

        summary = '====> [Weak {:.0f}%] Epoch: {} Joint loss: {:.4f}\tImage loss: {:.4f}\tText loss: {:.4f}'.format(
            100. * weak_perc, epoch, joint_meter.avg, image_meter.avg, text_meter.avg)
        print(summary)
Esempio n. 2
0
    def test():
        """Evaluate the multimodal VAE on ``test_loader``.

        Every batch is scored three ways: with both modalities as input, with
        the image only, and with the text only. Unlike training, the joint
        pass is always run.

        :return: ``(test_loss, (test_joint_loss, test_image_loss,
            test_text_loss))``. Each value is the sum of the per-batch losses
            divided by the number of batches, and ``test_loss`` is the sum of
            the three components.
        """
        vae.eval()
        test_joint_loss = 0
        test_image_loss = 0
        test_text_loss = 0

        for image, text in test_loader:
            if cuda:
                image, text = image.cuda(), text.cuda()
            # Inference only: mark the inputs volatile so no autograd graph is
            # kept for these forward passes.
            image = Variable(image, volatile=True)
            text = Variable(text, volatile=True)
            image = image.view(-1, 784)  # flatten image

            # in test we always care about the joint loss -- so the joint
            # pass is not gated by a random flip as it is in train
            recon_image_1, recon_text_1, mu_1, logvar_1 = vae(image, text)
            recon_image_2, recon_text_2, mu_2, logvar_2 = vae(image=image)
            recon_image_3, recon_text_3, mu_3, logvar_3 = vae(text=text)

            loss_1 = loss_function(mu_1,
                                   logvar_1,
                                   recon_image=recon_image_1,
                                   image=image,
                                   recon_text=recon_text_1,
                                   text=text,
                                   lambda_xy=1.,
                                   lambda_yx=1.)
            loss_2 = loss_function(mu_2,
                                   logvar_2,
                                   recon_image=recon_image_2,
                                   image=image,
                                   recon_text=recon_text_2,
                                   text=text,
                                   lambda_xy=1.,
                                   lambda_yx=1.)
            # text-only pass: the image reconstruction term is weighted by 0
            loss_3 = loss_function(mu_3,
                                   logvar_3,
                                   recon_image=recon_image_3,
                                   image=image,
                                   recon_text=recon_text_3,
                                   text=text,
                                   lambda_xy=0.,
                                   lambda_yx=1.)

            test_joint_loss += loss_1.data[0]
            test_image_loss += loss_2.data[0]
            test_text_loss += loss_3.data[0]

        # Sum first, then normalise everything by the number of batches.
        test_loss = test_joint_loss + test_image_loss + test_text_loss
        test_joint_loss /= len(test_loader)
        test_image_loss /= len(test_loader)
        test_text_loss /= len(test_loader)
        test_loss /= len(test_loader)

        print(
            '====> [Weak {:.0f}%] Test joint loss: {:.4f}\timage loss: {:.4f}\ttext loss:{:.4f}'
            .format(100. * weak_perc, test_joint_loss, test_image_loss,
                    test_text_loss))

        return test_loss, (test_joint_loss, test_image_loss, test_text_loss)
    def train(epoch):
        """Run one training epoch of the VAE on the text element of each batch.

        The first element of every ``train_loader`` batch is ignored; the
        second is fed to ``vae`` and scored via the ``recon_text``/``text``
        arguments of ``loss_function``.

        :param epoch: epoch number; only used in the progress messages.
        """
        vae.train()
        meter = AverageMeter()

        for step, (_, text) in enumerate(train_loader):
            text = Variable(text)
            if args.cuda:
                text = text.cuda()

            optimizer.zero_grad()
            reconstruction, mu, logvar = vae(text)
            batch_loss = loss_function(
                mu, logvar,
                recon_text=reconstruction, text=text,
                kl_lambda=kl_lambda, lambda_yx=1.)
            batch_loss.backward()
            meter.update(batch_loss.data[0], len(text))
            optimizer.step()

            # Only report every ``args.log_interval`` batches.
            if step % args.log_interval != 0:
                continue
            progress = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step * len(text), len(train_loader.dataset),
                100. * step / len(train_loader), meter.avg)
            print(progress)

        summary = '====> Epoch: {} Average loss: {:.4f}'.format(epoch, meter.avg)
        print(summary)
Esempio n. 4
0
    def train(epoch):
        """Run one training epoch of the VAE on the image element of each batch.

        The second element of every ``train_loader`` batch is ignored; the
        first is fed to ``vae`` and scored via the ``recon_image``/``image``
        arguments of ``loss_function``. The current ``kl_lambda`` is printed
        once at the start of the epoch.

        :param epoch: epoch number; only used in the progress messages.
        """
        print('Using KL Lambda: {}'.format(kl_lambda))
        vae.train()
        meter = AverageMeter()

        for step, (image, _) in enumerate(train_loader):
            image = Variable(image)
            if args.cuda:
                image = image.cuda()

            optimizer.zero_grad()
            reconstruction, mu, logvar = vae(image)
            # watch out for logvar -- could explode if learning rate is too high.
            batch_loss = loss_function(
                mu, logvar,
                recon_image=reconstruction, image=image,
                kl_lambda=kl_lambda, lambda_xy=1.)
            meter.update(batch_loss.data[0], len(image))
            batch_loss.backward()
            optimizer.step()

            # Only report every ``args.log_interval`` batches.
            if step % args.log_interval != 0:
                continue
            progress = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, step * len(image), len(train_loader.dataset),
                100. * step / len(train_loader), meter.avg)
            print(progress)

        summary = '====> Epoch: {} Average loss: {:.4f}'.format(epoch, meter.avg)
        print(summary)
Esempio n. 5
0
    def test():
        """Evaluate the VAE on the text element of each ``test_loader`` batch.

        :return: the sum of the per-batch losses divided by the number of
            batches in ``test_loader``.
        """
        vae.eval()
        running_total = 0

        for _, text in test_loader:
            if args.cuda:
                text = text.cuda()
            # volatile input: this pass is for evaluation only
            text = Variable(text, volatile=True)

            reconstruction, mu, logvar = vae(text)
            batch_loss = loss_function(
                mu, logvar,
                recon_text=reconstruction, text=text,
                kl_lambda=kl_lambda, lambda_yx=1.)
            running_total += batch_loss.data[0]

        mean_loss = running_total / len(test_loader)
        print('====> Test set loss: {:.4f}'.format(mean_loss))
        return mean_loss
Esempio n. 6
0
def train_step(inp, tar):
    """Run one optimisation step of ``transformer`` on a single batch.

    The target is split into two views shifted by one position along axis 1:
    everything but the last position is fed to the model, and everything but
    the first position is what the predictions are scored against.

    :param inp: source batch passed to ``transformer`` and ``create_masks``.
    :param tar: target batch; sliced along its second axis.
    """
    decoder_input = tar[:, :-1]
    expected = tar[:, 1:]

    masks = create_masks(inp, decoder_input)
    enc_padding_mask, combined_mask, dec_padding_mask = masks

    with tf.GradientTape() as tape:
        # The third positional argument (True) is passed through unchanged;
        # presumably it is the training flag -- confirm against the model.
        outputs = transformer(inp, decoder_input, True, enc_padding_mask,
                              combined_mask, dec_padding_mask)
        predictions, _ = outputs
        loss = loss_function(expected, predictions, loss_object)

    grads = tape.gradient(loss, transformer.trainable_variables)
    grads_and_vars = zip(grads, transformer.trainable_variables)
    optimizer.apply_gradients(grads_and_vars)

    # Feed the running trackers defined outside this function.
    train_loss(loss)
    train_accuracy(expected, predictions)
Esempio n. 7
0
def train_step(input, target, encoded_hidden):
    """
    Perform one teacher-forced training step on a single batch.

    The encoder is run once on ``input``; the decoder is then stepped through
    positions ``1 .. target.shape[1] - 1`` of ``target``, each time receiving
    the ground-truth token of the previous position as its input. The
    per-position losses are summed, and gradients of that sum are applied to
    the trainable variables of both ``encoder`` and ``decoder``.

    ``encoder``, ``decoder``, ``optimizer``, ``loss_function``,
    ``target_language`` and ``BATCH_SIZE`` are read from the enclosing scope.

    :param input: source batch passed to ``encoder``.
    :param target: target batch; indexed as ``target[:, t]``.
    :param encoded_hidden: initial hidden state passed to ``encoder``.
    :return: the summed loss divided by ``target.shape[1]``.
    """

    loss = 0

    # Only the forward pass needs to be recorded by the tape.
    with tf.GradientTape() as tape:
        encoded_output, encoded_hidden = encoder(input, encoded_hidden)
        # The decoder starts from the encoder's final hidden state.
        decoded_hidden = encoded_hidden
        print("decoded hidden")
        print(decoded_hidden.shape)
        print(" ")
        # First decoder input: the "<" token repeated for every batch row.
        decoded_input = tf.expand_dims([target_language.word_index["<"]] * BATCH_SIZE, 1)
        print("decoded input shape")
        print(decoded_input.shape)
        print(" ")
        print(" ")

        # Teacher forcing - feeding the target as the next input
        for t in range(1, target.shape[1]):
            # passing enc_output to the decoder
            predicted, decoded_hidden, _ = decoder(decoded_input, decoded_hidden, encoded_output)

            loss += loss_function(target[:, t], predicted)
            # using teacher forcing
            decoded_input = tf.expand_dims(target[:, t], 1)

    # Reported value is normalised by the full target length, although the
    # loop above accumulates one term fewer than that.
    batch_loss = (loss / int(target.shape[1]))

    # Gradients are taken of the summed loss, not of batch_loss.
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss
Esempio n. 8
0
    for batch_idx, (image, attrs) in enumerate(loader):
        if args.cuda:
            image, attrs = image.cuda(), attrs.cuda()
        image = Variable(image, volatile=True)
        attrs = Variable(attrs, volatile=True)

        recon_image_1, recon_attrs_1, mu_1, logvar_1 = vae(image=image,
                                                           attrs=attrs)
        recon_image_2, recon_attrs_2, mu_2, logvar_2 = vae(image=image)
        recon_image_3, recon_attrs_3, mu_3, logvar_3 = vae(attrs=attrs)

        joint_loss += loss_function(mu_1,
                                    logvar_1,
                                    recon_x=recon_image_1,
                                    x=image,
                                    recon_y=recon_attrs_1,
                                    y=attrs,
                                    kl_lambda=1.,
                                    lambda_x=1.,
                                    lambda_y=1.).data[0]
        image_loss += loss_function(mu_2,
                                    logvar_2,
                                    recon_x=recon_image_2,
                                    x=image,
                                    recon_y=recon_attrs_2,
                                    y=attrs,
                                    kl_lambda=1.,
                                    lambda_x=1.,
                                    lambda_y=1.).data[0]
        attrs_loss += loss_function(mu_3,
                                    logvar_3,
Esempio n. 9
0
    def train(epoch):
        """Run one training epoch with per-modality weak supervision.

        The joint (image + text) pass is run on every batch. An image-only
        pass is added when a uniform random draw falls below ``weak_perc_m1``,
        and a text-only pass is added when a second, independent draw falls
        below ``weak_perc_m2``. All losses that were computed for the batch
        are accumulated into one value before back-propagation.

        :param epoch: epoch number; only used in the progress messages.
        """
        # Re-seed on every call so that the per-batch decisions are identical
        # from epoch to epoch.
        random.seed(42)
        np.random.seed(42)
        vae.train()

        joint_meter = AverageMeter()
        image_meter = AverageMeter()
        text_meter = AverageMeter()

        for step, (image, text) in enumerate(train_loader):
            if cuda:
                image = image.cuda()
                text = text.cuda()
            image = Variable(image)
            text = Variable(text)
            optimizer.zero_grad()

            # Joint pass: always part of the objective.
            joint_image, joint_text, joint_mu, joint_logvar = vae(image, text)
            total_loss = loss_function(
                joint_mu, joint_logvar,
                recon_image=joint_image, image=image,
                recon_text=joint_text, text=text,
                kl_lambda=kl_lambda, lambda_xy=1., lambda_yx=1.)
            # Recorded before the single-modality terms are added below.
            joint_meter.update(total_loss.data[0], len(image))

            # First draw: decide whether to add the image-only pass.
            use_image_only = np.random.random() < weak_perc_m1
            if use_image_only:
                img_image, img_text, img_mu, img_logvar = vae(image=image)
                image_loss = loss_function(
                    img_mu, img_logvar,
                    recon_image=img_image, image=image,
                    recon_text=img_text, text=text,
                    kl_lambda=kl_lambda, lambda_xy=1., lambda_yx=1.)
                image_meter.update(image_loss.data[0], len(image))
                total_loss += image_loss

            # Second draw: decide whether to add the text-only pass.
            use_text_only = np.random.random() < weak_perc_m2
            if use_text_only:
                txt_image, txt_text, txt_mu, txt_logvar = vae(text=text)
                text_loss = loss_function(
                    txt_mu, txt_logvar,
                    recon_image=txt_image, image=image,
                    recon_text=txt_text, text=text,
                    kl_lambda=kl_lambda, lambda_xy=0., lambda_yx=1.)
                text_meter.update(text_loss.data[0], len(text))
                total_loss += text_loss

            total_loss.backward()
            optimizer.step()

            # Only report every ``log_interval`` batches.
            if step % log_interval != 0:
                continue
            progress = '[Weak (Image) {:.0f}% | Weak (Text) {:.0f}%] Train Epoch: {} [{}/{} ({:.0f}%)]\tJoint Loss: {:.6f}\tImage Loss: {:.6f}\tText Loss: {:.6f}'.format(
                100. * weak_perc_m1, 100. * weak_perc_m2, epoch,
                step * len(image), len(train_loader.dataset),
                100. * step / len(train_loader),
                joint_meter.avg, image_meter.avg, text_meter.avg)
            print(progress)

        summary = '====> [Weak (Image) {:.0f}% | Weak (Text) {:.0f}%] Epoch: {} Joint loss: {:.4f}\tImage loss: {:.4f}\tText loss: {:.4f}'.format(
            100. * weak_perc_m1, 100. * weak_perc_m2, epoch,
            joint_meter.avg, image_meter.avg, text_meter.avg)
        print(summary)