Example #1
def evaluate(vqa, data_loader, criterion, l2_criterion, args):
    """Calculates vqa average loss on data_loader.

    Args:
        vqa: visual question answering model.
        data_loader: Iterator for the data.
        criterion: The loss function used to evaluate the loss.
        l2_criterion: The loss function used to evaluate the l2 loss.
        args: ArgumentParser object.

    Returns:
        A tuple of (average generation loss, average reconstruction loss).
    """
    vqa.eval()
    total_gen_loss = 0.0
    total_kl = 0.0
    total_recon_image_loss = 0.0
    total_recon_question_loss = 0.0
    total_z_t_kl = 0.0
    total_t_kl_loss = 0.0
    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    start_time = time.time()
    for iterations, (images, questions, answers,
                     aindices) in enumerate(data_loader):

        # Set mini-batch dataset
        if torch.cuda.is_available():
            images = images.cuda()
            questions = questions.cuda()
            answers = answers.cuda()
            aindices = aindices.cuda()
        qlengths = process_lengths(questions)
        qlengths.sort(reverse=True)

        # Forward, Backward and Optimize
        image_features = vqa.encode_images(images)
        question_features = vqa.encode_questions(questions, qlengths)
        mus, logvars = vqa.encode_into_z(image_features, question_features)
        zs = vqa.reparameterize(mus, logvars)
        (outputs, _, other) = vqa.decode_answers(image_features,
                                                 zs,
                                                 answers=answers,
                                                 teacher_forcing_ratio=1.0)

        # Reorder the questions based on length.
        answers = torch.index_select(answers, 0, aindices)

        # Ignoring the start token.
        answers = answers[:, 1:]
        alengths = process_lengths(answers)
        alengths.sort(reverse=True)

        # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
        # (BATCH x MAX_LEN x VOCAB).
        outputs = [o.unsqueeze(1) for o in outputs]
        outputs = torch.cat(outputs, dim=1)
        outputs = torch.index_select(outputs, 0, aindices)

        # Calculate the loss.
        targets = pack_padded_sequence(answers, alengths, batch_first=True)[0]
        outputs = pack_padded_sequence(outputs, alengths, batch_first=True)[0]
        gen_loss = criterion(outputs, targets)
        total_gen_loss += gen_loss.item()

        # Get KL loss if it exists.
        kl_loss = gaussian_KL_loss(mus, logvars)
        total_kl += kl_loss.item()

        # Reconstruction.
        if not args.no_image_recon or not args.no_question_recon:
            image_targets = image_features.detach()
            question_targets = question_features.detach()
            recon_image_features, recon_question_features = vqa.reconstruct_inputs(
                image_targets, question_targets)

            if not args.no_image_recon:
                recon_i_loss = l2_criterion(recon_image_features,
                                            image_targets)
                total_recon_image_loss += recon_i_loss.item()
            if not args.no_question_recon:
                recon_q_loss = l2_criterion(recon_question_features,
                                            question_targets)
                total_recon_question_loss += recon_q_loss.item()

        # Quit after eval_steps.
        if args.eval_steps is not None and iterations >= args.eval_steps:
            break

        # Print logs
        if iterations % args.log_step == 0:
            delta_time = time.time() - start_time
            start_time = time.time()
            logging.info('Time: %.4f, Step [%d/%d], gen loss: %.4f, '
                         'KL: %.4f, I-recon: %.4f, Q-recon: %.4f' %
                         (delta_time, iterations, total_steps, total_gen_loss /
                          (iterations + 1), total_kl /
                          (iterations + 1), total_recon_image_loss /
                          (iterations + 1), total_recon_question_loss /
                          (iterations + 1)))
    total_info_loss = total_recon_image_loss + total_recon_question_loss
    return total_gen_loss / (iterations + 1), total_info_loss / (iterations +
                                                                 1)
Example #2
def evaluate(vqg, data_loader, criterion, l2_criterion, args):
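    """Calculates vqg average loss on data_loader.

    Args:
        vqg: visual question generation model.
        data_loader: Iterator for the data.
        criterion: The loss function used to evaluate the loss.
        l2_criterion: The loss function used to evaluate the l2 loss.
        args: ArgumentParser object.

    Returns:
        A tuple of (average generation loss, average reconstruction loss).
    """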
    vqg.eval()

    if (args.bayes):
        alpha = vqg.alpha

    total_gen_loss = 0.0
    total_kl = 0.0
    total_recon_image_loss = 0.0
    total_recon_category_loss = 0.0
    total_z_t_kl = 0.0
    total_t_kl_loss = 0.0
    regularisation_loss = 0.0
    c_loss = 0.0
    category_cycle_loss = 0.0

    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    start_time = time.time()
    for iterations, (images, questions, answers, categories,
                     qindices) in enumerate(data_loader):
        # TODO: remove answers from the data loader later.

        # Set mini-batch dataset
        if torch.cuda.is_available():
            images = images.cuda()
            questions = questions.cuda()
            answers = answers.cuda()
            categories = categories.cuda()
            qindices = qindices.cuda()
            if (args.bayes):
                alpha = vqg.alpha.cuda()

        # Forward, Backward and Optimize
        image_features = vqg.encode_images(images)
        category_features = vqg.encode_categories(categories)
        t_mus, t_logvars, ts = vqg.encode_into_t(image_features,
                                                 category_features)
        (outputs, _,
         other), pred_ques = vqg.decode_questions(image_features,
                                                  ts,
                                                  questions=questions,
                                                  teacher_forcing_ratio=1.0)

        # Reorder the questions based on length.
        questions = torch.index_select(questions, 0, qindices)
        total_loss = 0.0

        # Ignoring the start token.
        questions = questions[:, 1:]
        qlengths = process_lengths(questions)

        # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
        # (BATCH x MAX_LEN x VOCAB).
        outputs = [o.unsqueeze(1) for o in outputs]
        outputs = torch.cat(outputs, dim=1)
        outputs = torch.index_select(outputs, 0, qindices)

        if (args.step_two):
            category_cycle = vqg.encode_questions(pred_ques, qlengths)
            category_cycle_loss = criterion(category_cycle, categories)

            category_cycle_loss = category_cycle_loss.item()
            total_loss += args.lambda_c_cycle * category_cycle_loss

        # Calculate the loss.
        targets = pack_padded_sequence(questions, qlengths,
                                       batch_first=True)[0]
        outputs = pack_padded_sequence(outputs, qlengths, batch_first=True)[0]

        gen_loss = criterion(outputs, targets)
        total_gen_loss += gen_loss.item()

        # Get KL loss if it exists.
        if (args.bayes):
            regularisation_loss = l2_criterion(alpha.pow(-1),
                                               torch.ones_like(alpha))
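            # KL divergence between N(t_mus, exp(t_logvars)) and a zero-mean
            # Gaussian prior with variance alpha^-2, summed over the batch.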
            kl_loss = -0.5 * torch.sum(1 + t_logvars + alpha.pow(2).log() -
                                       alpha.pow(2) *
                                       (t_mus.pow(2) + t_logvars.exp()))
            total_kl += kl_loss.item() + regularisation_loss.item()
        else:
            kl_loss = gaussian_KL_loss(t_mus, t_logvars).item()
            total_kl += args.lambda_t * kl_loss

        # Reconstruction.
        if not args.no_image_recon or not args.no_category_space:
            image_targets = image_features.detach()
            category_targets = category_features.detach()

            recon_image_features, recon_category_features = vqg.reconstruct_inputs(
                image_targets, category_targets)

            if not args.no_image_recon:
                recon_i_loss = l2_criterion(recon_image_features,
                                            image_targets)
                total_recon_image_loss += recon_i_loss.item()

            if not args.no_category_space:
                recon_c_loss = l2_criterion(recon_category_features,
                                            category_targets)
                total_recon_category_loss += recon_c_loss.item()

        # Quit after eval_steps.
        if args.eval_steps is not None and iterations >= args.eval_steps:
            break

        # Print logs
        if iterations % args.log_step == 0:
            delta_time = time.time() - start_time
            start_time = time.time()
            logging.info(
                'Time: %.4f, Step [%d/%d], gen loss: %.4f, '
                'KL: %.4f, I-recon: %.4f, C-recon: %.4f, C-cycle: %.4f, Regularisation: %.4f'
                % (delta_time, iterations, total_steps, total_gen_loss /
                   (iterations + 1), total_kl /
                   (iterations + 1), total_recon_image_loss /
                   (iterations + 1), total_recon_category_loss /
                   (iterations + 1), category_cycle_loss /
                   (iterations + 1), regularisation_loss / (iterations + 1)))

    total_info_loss = total_recon_image_loss + total_recon_category_loss
    return total_gen_loss / (iterations + 1), total_info_loss / (iterations +
                                                                 1)
Example #3
def train(args):
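    """Trains the VQA model specified by args.

    Builds the data loaders, model, optimizers and LR schedulers, then runs
    the training loop with periodic evaluation and checkpointing.

    Args:
        args: ArgumentParser object.
    """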
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Save the arguments.
    with open(os.path.join(args.model_path, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    logfile = os.path.join(args.model_path, 'train.log')
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(args.crop_size,
                                     scale=(1.00, 1.2),
                                     ratio=(0.75, 1.3333333333333333)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)

    # Build data loader
    logging.info("Building data loader...")
    train_sampler = None
    val_sampler = None
    if os.path.exists(args.train_dataset_weights):
        train_weights = json.load(open(args.train_dataset_weights))
        train_weights = torch.DoubleTensor(train_weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))
    if os.path.exists(args.val_dataset_weights):
        val_weights = json.load(open(args.val_dataset_weights))
        val_weights = torch.DoubleTensor(val_weights)
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            val_weights, len(val_weights))
    data_loader = get_loader(args.dataset,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             max_examples=args.max_examples,
                             sampler=train_sampler)
    val_data_loader = get_loader(args.val_dataset,
                                 transform,
                                 args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 max_examples=args.max_examples,
                                 sampler=val_sampler)
    logging.info("Done")

    vqa = create_model(args, vocab)

    if args.load_model is not None:
        vqa.load_state_dict(torch.load(args.load_model))
    logging.info("Done")

    # Loss criterion.
    pad = vocab(vocab.SYM_PAD)  # Set loss weight for 'pad' symbol to 0
    criterion = nn.CrossEntropyLoss(ignore_index=pad)
    l2_criterion = nn.MSELoss()

    # Setup GPUs.
    if torch.cuda.is_available():
        logging.info("Using available GPU...")
        vqa.cuda()
        criterion.cuda()
        l2_criterion.cuda()

    # Parameters to train.
    gen_params = vqa.generator_parameters()
    info_params = vqa.info_parameters()
    learning_rate = args.learning_rate
    info_learning_rate = args.info_learning_rate
    gen_optimizer = torch.optim.Adam(gen_params, lr=learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)
    scheduler = ReduceLROnPlateau(optimizer=gen_optimizer,
                                  mode='min',
                                  factor=0.1,
                                  patience=args.patience,
                                  verbose=True,
                                  min_lr=1e-7)
    info_scheduler = ReduceLROnPlateau(optimizer=info_optimizer,
                                       mode='min',
                                       factor=0.1,
                                       patience=args.patience,
                                       verbose=True,
                                       min_lr=1e-7)

    # Train the model.
    total_steps = len(data_loader)
    start_time = time.time()
    n_steps = 0

    # Optional losses. Initialized here for logging.
    recon_question_loss = 0.0
    recon_image_loss = 0.0
    kl_loss = 0.0
    z_t_kl = 0.0
    t_kl = 0.0
    for epoch in range(args.num_epochs):
        for i, (images, questions, answers,
                aindices) in enumerate(data_loader):
            n_steps += 1

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()
                aindices = aindices.cuda()
            qlengths = process_lengths(questions)
            qlengths.sort(reverse=True)

            # Eval now.
            if (args.eval_every_n_steps is not None
                    and n_steps >= args.eval_every_n_steps
                    and n_steps % args.eval_every_n_steps == 0):
                run_eval(vqa, val_data_loader, criterion, l2_criterion, args,
                         epoch, scheduler, info_scheduler)
                compare_outputs(images, answers, questions, qlengths, vqa,
                                vocab, logging, args)

            # Forward.
            vqa.train()
            gen_optimizer.zero_grad()
            info_optimizer.zero_grad()
            image_features = vqa.encode_images(images)
            question_features = vqa.encode_questions(questions, qlengths)

            # Question generation.
            mus, logvars = vqa.encode_into_z(image_features, question_features)
            zs = vqa.reparameterize(mus, logvars)
            (outputs, _, _) = vqa.decode_answers(image_features,
                                                 zs,
                                                 answers=answers,
                                                 teacher_forcing_ratio=1.0)

            # Reorder the questions based on length.
            answers = torch.index_select(answers, 0, aindices)

            # Ignoring the start token.
            answers = answers[:, 1:]
            alengths = process_lengths(answers)
            alengths.sort(reverse=True)

            # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
            # (BATCH x MAX_LEN x VOCAB).
            outputs = [o.unsqueeze(1) for o in outputs]
            outputs = torch.cat(outputs, dim=1)
            outputs = torch.index_select(outputs, 0, aindices)

            # Calculate the generation loss.
            targets = pack_padded_sequence(answers, alengths,
                                           batch_first=True)[0]
            outputs = pack_padded_sequence(outputs, alengths,
                                           batch_first=True)[0]
            gen_loss = criterion(outputs, targets)
            total_loss = 0.0
            total_loss += args.lambda_gen * gen_loss
            gen_loss = gen_loss.item()

            # Variational loss.
            kl_loss = gaussian_KL_loss(mus, logvars)
            total_loss += args.lambda_z * kl_loss
            kl_loss = kl_loss.item()

            # Generator Backprop.
            total_loss.backward()
            gen_optimizer.step()

            # Reconstruction loss.
            recon_image_loss = 0.0
            recon_question_loss = 0.0
            if not args.no_question_recon or not args.no_image_recon:
                total_info_loss = 0.0
                gen_optimizer.zero_grad()
                info_optimizer.zero_grad()
                question_targets = question_features.detach()
                image_targets = image_features.detach()

                recon_image_features, recon_question_features = vqa.reconstruct_inputs(
                    image_targets, question_targets)

                # Question reconstruction loss.
                if not args.no_question_recon:
                    recon_q_loss = l2_criterion(recon_question_features,
                                                question_targets)
                    total_info_loss += args.lambda_a * recon_q_loss
                    recon_question_loss = recon_q_loss.item()

                # Image reconstruction loss.
                if not args.no_image_recon:
                    recon_i_loss = l2_criterion(recon_image_features,
                                                image_targets)
                    total_info_loss += args.lambda_i * recon_i_loss
                    recon_image_loss = recon_i_loss.item()

                # Info backprop.
                total_info_loss.backward()
                info_optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info(
                    'Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                    'LR: %f, gen: %.4f, KL: %.4f, '
                    'I-recon: %.4f, Q-recon: %.4f' %
                    (delta_time, epoch, args.num_epochs, i, total_steps,
                     gen_optimizer.param_groups[0]['lr'], gen_loss, kl_loss,
                     recon_image_loss, recon_question_loss))

            # Save the models
            if args.save_step is not None and (i + 1) % args.save_step == 0:
                torch.save(
                    vqa.state_dict(),
                    os.path.join(args.model_path,
                                 'vqa-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        torch.save(
            vqa.state_dict(),
            os.path.join(args.model_path, 'vqa-tf-%d.pkl' % (epoch + 1)))

        # Evaluation and learning rate updates.
        run_eval(vqa, val_data_loader, criterion, l2_criterion, args, epoch,
                 scheduler, info_scheduler)
Example #4
def train(args):
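    """Trains the VQG model specified by args.

    Builds the data loaders, model, optimizers and LR schedulers, then runs
    the training loop with periodic evaluation and checkpointing.

    Args:
        args: ArgumentParser object.
    """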
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Save the arguments.
    with open(os.path.join(args.model_path, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    log_file_name = 'train_' + args.train_log_file_suffix + '.log'
    logfile = os.path.join(args.model_path, log_file_name)
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(args.crop_size,
                                     scale=(1.00, 1.2),
                                     ratio=(0.75, 1.3333333333333333)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)

    # Load the category types.
    cat2name = json.load(open(args.cat2name))

    # Build data loader
    logging.info("Building data loader...")
    train_sampler = None
    val_sampler = None
    if os.path.exists(args.train_dataset_weights):
        train_weights = json.load(open(args.train_dataset_weights))
        train_weights = torch.DoubleTensor(train_weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))

    if os.path.exists(args.val_dataset_weights):
        val_weights = json.load(open(args.val_dataset_weights))
        val_weights = torch.DoubleTensor(val_weights)
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            val_weights, len(val_weights))

    data_loader = get_loader(args.dataset,
                             transform,
                             args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             max_examples=args.max_examples,
                             sampler=train_sampler)
    val_data_loader = get_loader(args.val_dataset,
                                 transform,
                                 args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 max_examples=args.max_examples,
                                 sampler=val_sampler)

    logging.info("Done loading data.")

    vqg = create_model(args, vocab)
    if args.load_model is not None:
        vqg.load_state_dict(torch.load(args.load_model))

    logging.info("Done")

    # Loss criterion.
    pad = vocab(vocab.SYM_PAD)  # Set loss weight for 'pad' symbol to 0
    criterion = nn.CrossEntropyLoss()
    criterion2 = nn.MultiMarginLoss()
    l2_criterion = nn.MSELoss()

    alpha = None

    if (args.bayes):
        alpha = vqg.alpha

    # Setup GPUs.
    if torch.cuda.is_available():
        logging.info("Using available GPU...")
        vqg.cuda()
        criterion.cuda()
        l2_criterion.cuda()
        if (alpha is not None):
            alpha = alpha.cuda()

    gen_params = vqg.generator_parameters()
    info_params = vqg.info_parameters()

    learning_rate = args.learning_rate
    info_learning_rate = args.info_learning_rate

    gen_optimizer = torch.optim.Adam(gen_params, lr=learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)

    if (args.step_two):
        cycle_params = vqg.cycle_params()
        cycle_optimizer = torch.optim.Adam(cycle_params, lr=learning_rate)

    if (args.center_loss):
        center_loss = CenterLoss(num_classes=args.num_categories,
                                 feat_dim=args.z_size,
                                 use_gpu=True)
        optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)

    scheduler = ReduceLROnPlateau(optimizer=gen_optimizer,
                                  mode='min',
                                  factor=0.5,
                                  patience=args.patience,
                                  verbose=True,
                                  min_lr=1e-7)
    cycle_scheduler = None
    if (args.step_two):
        cycle_scheduler = ReduceLROnPlateau(optimizer=cycle_optimizer,
                                            mode='min',
                                            factor=0.99,
                                            patience=args.patience,
                                            verbose=True,
                                            min_lr=1e-7)
    info_scheduler = ReduceLROnPlateau(optimizer=info_optimizer,
                                       mode='min',
                                       factor=0.5,
                                       patience=args.patience,
                                       verbose=True,
                                       min_lr=1e-7)

    # Train the model.
    total_steps = len(data_loader)
    start_time = time.time()
    n_steps = 0

    # Optional losses. Initialized here for logging.
    recon_category_loss = 0.0
    recon_image_loss = 0.0
    kl_loss = 0.0
    category_cycle_loss = 0.0
    regularisation_loss = 0.0
    c_loss = 0.0
    cycle_loss = 0.0

    if (args.step_two):
        category_cycle_loss = 0.0

    if (args.bayes):
        regularisation_loss = 0.0

    if (args.center_loss):
        loss_center = 0.0
        c_loss = 0.0

    for epoch in range(args.num_epochs):
        for i, (images, questions, answers, categories,
                qindices) in enumerate(data_loader):
            n_steps += 1
            # TODO: remove answers from the data loader later.

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()
                categories = categories.cuda()
                qindices = qindices.cuda()
                if (args.bayes):
                    alpha = alpha.cuda()

            # Eval now.
            if (args.eval_every_n_steps is not None
                    and n_steps >= args.eval_every_n_steps
                    and n_steps % args.eval_every_n_steps == 0):
                run_eval(vqg, val_data_loader, criterion, l2_criterion, args,
                         epoch, scheduler, info_scheduler)
                compare_outputs(images, questions, answers, categories, vqg,
                                vocab, logging, cat2name, args)

            # Forward.
            vqg.train()
            gen_optimizer.zero_grad()
            info_optimizer.zero_grad()
            if (args.step_two):
                cycle_optimizer.zero_grad()
            if (args.center_loss):
                optimizer_centloss.zero_grad()

            image_features = vqg.encode_images(images)
            category_features = vqg.encode_categories(categories)

            # Question generation.
            t_mus, t_logvars, ts = vqg.encode_into_t(image_features,
                                                     category_features)

            if (args.center_loss):
                loss_center = 0.0
                c_loss = center_loss(ts, categories)
                loss_center += args.lambda_centerloss * c_loss
                c_loss = c_loss.item()
                loss_center.backward(retain_graph=True)

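                # Scale the center gradients by 1/lambda_centerloss so the
                # class centers are updated independently of the loss weight.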
                for param in center_loss.parameters():
                    param.grad.data *= (1. / args.lambda_centerloss)

                optimizer_centloss.step()

            qlengths_prev = process_lengths(questions)

            (outputs, _,
             _), pred_ques = vqg.decode_questions(image_features,
                                                  ts,
                                                  questions=questions,
                                                  teacher_forcing_ratio=1.0)

            # Reorder the questions based on length.
            questions = torch.index_select(questions, 0, qindices)

            # Ignoring the start token.
            questions = questions[:, 1:]
            qlengths = process_lengths(questions)

            # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
            # (BATCH x MAX_LEN x VOCAB).

            outputs = [o.unsqueeze(1) for o in outputs]
            outputs = torch.cat(outputs, dim=1)

            outputs = torch.index_select(outputs, 0, qindices)

            if (args.step_two):
                category_cycle_loss = 0.0
                category_cycle = vqg.encode_questions(pred_ques, qlengths)
                cycle_loss = criterion(category_cycle, categories)
                category_cycle_loss += args.lambda_c_cycle * cycle_loss
                cycle_loss = cycle_loss.item()
                category_cycle_loss.backward(retain_graph=True)
                cycle_optimizer.step()

            # Calculate the generation loss.
            targets = pack_padded_sequence(questions,
                                           qlengths,
                                           batch_first=True)[0]
            outputs = pack_padded_sequence(outputs, qlengths,
                                           batch_first=True)[0]

            gen_loss = criterion(outputs, targets)
            total_loss = 0.0
            total_loss += args.lambda_gen * gen_loss
            gen_loss = gen_loss.item()

            # Variational loss.
            if (args.bayes):
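                # KL divergence between N(t_mus, exp(t_logvars)) and a
                # zero-mean Gaussian prior with variance alpha^-2, summed
                # over the batch.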
                kl_loss = -0.5 * torch.sum(1 + t_logvars + alpha.pow(2).log() -
                                           alpha.pow(2) *
                                           (t_mus.pow(2) + t_logvars.exp()))
                regularisation_loss = l2_criterion(alpha.pow(-1),
                                                   torch.ones_like(alpha))
                total_loss += args.lambda_t * kl_loss + args.lambda_reg * regularisation_loss
                kl_loss = kl_loss.item()
                regularisation_loss = regularisation_loss.item()
            else:
                kl_loss = gaussian_KL_loss(t_mus, t_logvars)
                total_loss += args.lambda_t * kl_loss
                kl_loss = kl_loss.item()

            # Generator Backprop.
            total_loss.backward(retain_graph=True)
            gen_optimizer.step()

            # Reconstruction loss.
            recon_image_loss = 0.0
            recon_category_loss = 0.0

            if not args.no_category_space or not args.no_image_recon:
                total_info_loss = 0.0
                category_targets = category_features.detach()
                image_targets = image_features.detach()
                recon_image_features, recon_category_features = vqg.reconstruct_inputs(
                    image_targets, category_targets)

                # Category reconstruction loss.
                if not args.no_category_space:
                    recon_c_loss = l2_criterion(recon_category_features,
                                                category_targets)
                    total_info_loss += args.lambda_c * recon_c_loss
                    recon_category_loss = recon_c_loss.item()

                # Image reconstruction loss.
                if not args.no_image_recon:
                    recon_i_loss = l2_criterion(recon_image_features,
                                                image_targets)
                    total_info_loss += args.lambda_i * recon_i_loss
                    recon_image_loss = recon_i_loss.item()

                # Info backprop.
                total_info_loss.backward()
                info_optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info(
                    'Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                    'LR: %f, Center-Loss: %.4f, KL: %.4f, '
                    'I-recon: %.4f, C-recon: %.4f, C-cycle: %.4f, Regularisation: %.4f'
                    % (delta_time, epoch, args.num_epochs, i, total_steps,
                       gen_optimizer.param_groups[0]['lr'], c_loss, kl_loss,
                       recon_image_loss, recon_category_loss, cycle_loss,
                       regularisation_loss))

            # Save the models
            if args.save_step is not None and (i + 1) % args.save_step == 0:
                torch.save(
                    vqg.state_dict(),
                    os.path.join(args.model_path,
                                 'vqg-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        torch.save(
            vqg.state_dict(),
            os.path.join(args.model_path, 'vqg-tf-%d.pkl' % (epoch + 1)))

        if (args.center_loss):
            torch.save(
                center_loss.state_dict(),
                os.path.join(args.model_path,
                             'closs-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        # Evaluation and learning rate updates.
        run_eval(vqg, val_data_loader, criterion, l2_criterion, args, epoch,
                 scheduler, info_scheduler)