def evaluate(model, data_loader, vocab, args, params):
    model.eval()
    preds = []
    gts = []
    bar = progressbar.ProgressBar(maxval=len(data_loader))

    for i, (sentences, labels, _) in enumerate(data_loader):
        # Set mini-batch dataset.
        if torch.cuda.is_available():
            sentences = sentences.cuda()
            labels = labels.cuda()
        lengths = process_lengths(sentences)
        lengths.sort(reverse=True)

        predictions = model(sentences, lengths)
        predictions = torch.max(predictions, 1)[1]
        for p in predictions:
            preds.append(p.item())

        for l in labels:
            gts.append(l.item())
        bar.update(i)

    print('=' * 80)
    print('GROUND TRUTH')
    print(gts[:args.num_show])
    print('-' * 80)
    print('PREDICTIONS')
    print(preds[:args.num_show])
    print('=' * 80)
    acc = binary_accuracy(preds, gts)

    print('Acc : ', acc)
    return acc, gts, preds
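
The snippets on this page call two small helpers, process_lengths and binary_accuracy, that are not shown here. A minimal sketch of plausible implementations, assuming the pad index is 0 and that predictions may arrive either as raw logits or as already-decoded class indices (the originals may differ):

import torch


def process_lengths(inputs, pad=0):
    """Number of non-pad tokens in each padded row (hypothetical helper)."""
    return (inputs != pad).long().sum(dim=1).tolist()


def binary_accuracy(predictions, labels):
    """Fraction of predictions matching the labels (hypothetical helper)."""
    predictions = torch.as_tensor(predictions)
    labels = torch.as_tensor(labels)
    if predictions.dim() > 1:  # Raw logits: take the argmax class per row.
        predictions = predictions.max(1)[1]
    return (predictions == labels).float().mean()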
Example #2
def evaluate(vocab, vqa, data_loader, criterion, epoch, args):
    """Calculates vqg average loss on data_loader.

    Args:
        vocab: questions and answers vocabulary.
        vqa: visual question answering model.
        data_loader: Iterator for the data.
        criterion: The criterion function used to evaluate the loss.
        args: ArgumentParser object.

    Returns:
        A float value of average loss.
    """
    gts, gens, qs = [], [], []
    vqa.eval()
    total_loss = 0.0
    total_correct = 0.0
    iterations = 0
    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    start_time = time.time()
    for i, (feats, questions, categories) in enumerate(data_loader):

        # Set mini-batch dataset.
        if torch.cuda.is_available():
            feats = feats.cuda()
            questions = questions.cuda()
            categories = categories.cuda()
        qlengths = process_lengths(questions)

        # Forward.
        outputs = vqa(feats, questions, qlengths)
        loss = criterion(outputs, categories)
        preds = outputs.max(1)[1]

        # Accumulate loss and accuracy.
        total_loss += loss.item()
        total_correct += accuracy(preds, categories)
        iterations += 1

        # Quit after eval_steps.
        if args.eval_steps is not None and i >= args.eval_steps:
            break
        q, gen, gt = parse_outputs(preds, questions, categories, vocab)
        gts.extend(gt)
        gens.extend(gen)
        qs.extend(q)

        # Print logs.
        if i % args.log_step == 0:
            delta_time = time.time() - start_time
            start_time = time.time()
            logging.info('Time: %.4f, Step [%d/%d], '
                         'Avg Loss: %.4f, Avg Acc: %.4f' %
                         (delta_time, i, total_steps, total_loss / iterations,
                          total_correct / iterations))
    # Compare model reconstruction to target
    compare_outputs(gens, qs, gts, logging)
    return total_loss / iterations
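
The accuracy helper used above is likewise external to this listing. A minimal sketch, assuming preds and categories are 1-D tensors of class indices (the original may differ):

def accuracy(preds, categories):
    """Fraction of predicted class indices matching the targets (hypothetical helper)."""
    return (preds == categories).float().mean().item()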
Example #3
def evaluate(model, data_loader, criterion, args):

    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    for i, (sentences, labels, qindices) in enumerate(data_loader):
        # Set mini-batch dataset.
        if torch.cuda.is_available():
            sentences = sentences.cuda()
            labels = labels.cuda()
            qindices = qindices.cuda()
        lengths = process_lengths(sentences)
        lengths.sort(reverse=True)
        # Forward pass.
        predictions = model(sentences, lengths)

        loss = criterion(predictions, labels)
        # Compute the binary accuracy.
        acc = binary_accuracy(predictions, labels)
        # Accumulate loss and accuracy over the epoch.
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    logging.info('\t Val-Loss: %.4f | Val-Acc: %.2f ' %
                 (epoch_loss / len(data_loader),
                  epoch_acc / len(data_loader) * 100))
Example #4
def evaluate(vqa, data_loader, vocab, args, params):
    """Runs BLEU, METEOR, CIDEr and distinct n-gram scores.

    Args:
        vqa: question answering model.
        data_loader: Iterator for the data.
        args: ArgumentParser object.
        params: ArgumentParser object.

    Returns:
        A float value of average loss.
    """
    vqa.eval()
    preds = []
    gts = []
    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    bar = progressbar.ProgressBar(maxval=total_steps)
    for iterations, (images, questions, categories) in enumerate(data_loader):

        # Set mini-batch dataset
        if torch.cuda.is_available():
            images = images.cuda()
            questions = questions.cuda()
            categories = categories.cuda()
        qlengths = process_lengths(questions)

        # Predict.
        outputs = vqa(images, questions, qlengths)
        out = outputs.max(1)[1]

        _, pred, gt = parse_outputs(out, questions, categories, vocab)
        gts.extend(gt)
        preds.extend(pred)

        bar.update(iterations)
        if args.eval_steps is not None and iterations >= args.eval_steps:
            break

    print('=' * 80)
    print('GROUND TRUTH')
    print(gts[:args.num_show])
    print('-' * 80)
    print('PREDICTIONS')
    print(preds[:args.num_show])
    print('=' * 80)
    scores = accuracy_score(gts, preds)
    return scores, gts, preds
Example #5
def evaluate(vqg, data_loader, vocab, args, params):
    """Runs BLEU, METEOR, CIDEr and distinct n-gram scores.

    Args:
        vqg: question generation model.
        data_loader: Iterator for the data.
        vocab: questions and answers vocabulary.
        args: ArgumentParser object.
        params: ArgumentParser object.

    Returns:
        The NLG metric scores along with the ground truth and generated questions.
    """
    vqg.eval()
    nlge = NLGEval(no_glove=True, no_skipthoughts=True)
    preds = []
    gts = []
    bar = progressbar.ProgressBar(maxval=len(data_loader))
    for iterations, (images, questions, answers,
                     categories, _) in enumerate(data_loader):

        # Set mini-batch dataset
        if torch.cuda.is_available():
            images = images.cuda()
            answers = answers.cuda()
            categories = categories.cuda()
        alengths = process_lengths(answers)

        # Predict.
        if args.from_answer:
            outputs = vqg.predict_from_answer(images, answers, alengths)
        else:
            outputs = vqg.predict_from_category(images, categories)
        for i in range(images.size(0)):
            output = vocab.tokens_to_words(outputs[i])
            preds.append(output)

            question = vocab.tokens_to_words(questions[i])
            gts.append(question)
        bar.update(iterations)
        
    print('=' * 80)
    print('GROUND TRUTH')
    print(gts[:args.num_show])
    print('-' * 80)
    print('PREDICTIONS')
    print(preds[:args.num_show])
    print('=' * 80)
    scores = nlge.compute_metrics(ref_list=[gts], hyp_list=preds)
    return scores, gts, preds
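
For context, NLGEval comes from the nlg-eval package: compute_metrics takes a list of reference lists plus a list of hypothesis strings and returns a dict of corpus-level scores (BLEU, METEOR, ROUGE_L, CIDEr, depending on the installed configuration). A small standalone usage sketch with made-up sentences:

from nlgeval import NLGEval

nlge = NLGEval(no_glove=True, no_skipthoughts=True)
references = ['what color is the cat ?', 'how many people are there ?']
hypotheses = ['what color is the dog ?', 'how many people are shown ?']
# ref_list is a list of reference lists: one inner list per reference set.
scores = nlge.compute_metrics(ref_list=[references], hyp_list=hypotheses)
print(scores)  # e.g. {'Bleu_1': ..., 'METEOR': ..., 'ROUGE_L': ..., 'CIDEr': ...}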
Example #6
def evaluate(vqa, data_loader, criterion, l2_criterion, args):
    """Calculates vqa average loss on data_loader.

    Args:
        vqa: visual question answering model.
        data_loader: Iterator for the data.
        criterion: The loss function used to evaluate the loss.
        l2_criterion: The loss function used to evaluate the l2 loss.
        args: ArgumentParser object.

    Returns:
        A tuple of the average generation loss and the average reconstruction loss.
    """
    vqa.eval()
    total_gen_loss = 0.0
    total_kl = 0.0
    total_recon_image_loss = 0.0
    total_recon_question_loss = 0.0
    total_z_t_kl = 0.0
    total_t_kl_loss = 0.0
    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    start_time = time.time()
    for iterations, (images, questions, answers,
                     aindices) in enumerate(data_loader):

        # Set mini-batch dataset
        if torch.cuda.is_available():
            images = images.cuda()
            questions = questions.cuda()
            answers = answers.cuda()
            aindices = aindices.cuda()
        qlengths = process_lengths(questions)
        qlengths.sort(reverse=True)

        # Forward, Backward and Optimize
        image_features = vqa.encode_images(images)
        question_features = vqa.encode_questions(questions, qlengths)
        mus, logvars = vqa.encode_into_z(image_features, question_features)
        zs = vqa.reparameterize(mus, logvars)
        (outputs, _, other) = vqa.decode_answers(image_features,
                                                 zs,
                                                 answers=answers,
                                                 teacher_forcing_ratio=1.0)

        # Reorder the questions based on length.
        answers = torch.index_select(answers, 0, aindices)

        # Ignoring the start token.
        answers = answers[:, 1:]
        alengths = process_lengths(answers)
        alengths.sort(reverse=True)

        # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
        # (BATCH x MAX_LEN x VOCAB).
        outputs = [o.unsqueeze(1) for o in outputs]
        outputs = torch.cat(outputs, dim=1)
        outputs = torch.index_select(outputs, 0, aindices)

        # Calculate the loss.
        targets = pack_padded_sequence(answers, alengths, batch_first=True)[0]
        outputs = pack_padded_sequence(outputs, alengths, batch_first=True)[0]
        gen_loss = criterion(outputs, targets)
        total_gen_loss += gen_loss.item()

        # Get KL loss if it exists.
        kl_loss = gaussian_KL_loss(mus, logvars)
        total_kl += kl_loss.item()

        # Reconstruction.
        if not args.no_image_recon or not args.no_question_recon:
            image_targets = image_features.detach()
            question_targets = question_features.detach()
            recon_image_features, recon_question_features = vqa.reconstruct_inputs(
                image_targets, question_targets)

            if not args.no_image_recon:
                recon_i_loss = l2_criterion(recon_image_features,
                                            image_targets)
                total_recon_image_loss += recon_i_loss.item()
            if not args.no_question_recon:
                recon_q_loss = l2_criterion(recon_question_features,
                                            question_targets)
                total_recon_question_loss += recon_q_loss.item()

        # Quit after eval_steps.
        if args.eval_steps is not None and iterations >= args.eval_steps:
            break

        # Print logs
        if iterations % args.log_step == 0:
            delta_time = time.time() - start_time
            start_time = time.time()
            logging.info('Time: %.4f, Step [%d/%d], gen loss: %.4f, '
                         'KL: %.4f, I-recon: %.4f, Q-recon: %.4f' %
                         (delta_time, iterations, total_steps, total_gen_loss /
                          (iterations + 1), total_kl /
                          (iterations + 1), total_recon_image_loss /
                          (iterations + 1), total_recon_question_loss /
                          (iterations + 1)))
    total_info_loss = total_recon_image_loss + total_recon_question_loss
    return total_gen_loss / (iterations + 1), total_info_loss / (iterations +
                                                                 1)
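
gaussian_KL_loss is another helper defined outside these snippets. The name suggests the standard closed-form KL divergence between the approximate posterior N(mu, sigma^2) and a unit Gaussian prior; a minimal sketch, assuming a batch-mean reduction (the original may sum or scale differently):

import torch


def gaussian_KL_loss(mus, logvars):
    """KL(N(mu, sigma^2) || N(0, I)) per example, averaged over the batch (hypothetical)."""
    kl = -0.5 * torch.sum(1 + logvars - mus.pow(2) - logvars.exp(), dim=1)
    return kl.mean()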
Example #7
def train(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Save the arguments.
    with open(os.path.join(args.model_path, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    logfile = os.path.join(args.model_path, 'train.log')
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(args.crop_size,
                                     scale=(1.00, 1.2),
                                     ratio=(0.75, 1.3333333333333333)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)

    # Build data loader
    logging.info("Building data loader...")
    train_sampler = None
    val_sampler = None
    if os.path.exists(args.train_dataset_weights):
        train_weights = json.load(open(args.train_dataset_weights))
        train_weights = torch.DoubleTensor(train_weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))
    if os.path.exists(args.val_dataset_weights):
        val_weights = json.load(open(args.val_dataset_weights))
        val_weights = torch.DoubleTensor(val_weights)
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            val_weights, len(val_weights))
    data_loader = get_loader(args.dataset,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             max_examples=args.max_examples,
                             sampler=train_sampler)
    val_data_loader = get_loader(args.val_dataset,
                                 transform,
                                 args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 max_examples=args.max_examples,
                                 sampler=val_sampler)
    logging.info("Done")

    vqa = create_model(args, vocab)

    if args.load_model is not None:
        vqa.load_state_dict(torch.load(args.load_model))
    logging.info("Done")

    # Loss criterion.
    pad = vocab(vocab.SYM_PAD)  # Set loss weight for 'pad' symbol to 0
    criterion = nn.CrossEntropyLoss(ignore_index=pad)
    l2_criterion = nn.MSELoss()

    # Setup GPUs.
    if torch.cuda.is_available():
        logging.info("Using available GPU...")
        vqa.cuda()
        criterion.cuda()
        l2_criterion.cuda()

    # Parameters to train.
    gen_params = vqa.generator_parameters()
    info_params = vqa.info_parameters()
    learning_rate = args.learning_rate
    info_learning_rate = args.info_learning_rate
    gen_optimizer = torch.optim.Adam(gen_params, lr=learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)
    scheduler = ReduceLROnPlateau(optimizer=gen_optimizer,
                                  mode='min',
                                  factor=0.1,
                                  patience=args.patience,
                                  verbose=True,
                                  min_lr=1e-7)
    info_scheduler = ReduceLROnPlateau(optimizer=info_optimizer,
                                       mode='min',
                                       factor=0.1,
                                       patience=args.patience,
                                       verbose=True,
                                       min_lr=1e-7)

    # Train the model.
    total_steps = len(data_loader)
    start_time = time.time()
    n_steps = 0

    # Optional losses. Initialized here for logging.
    recon_question_loss = 0.0
    recon_image_loss = 0.0
    kl_loss = 0.0
    z_t_kl = 0.0
    t_kl = 0.0
    for epoch in range(args.num_epochs):
        for i, (images, questions, answers,
                aindices) in enumerate(data_loader):
            n_steps += 1

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()
                aindices = aindices.cuda()
            qlengths = process_lengths(questions)
            qlengths.sort(reverse=True)

            # Eval now.
            if (args.eval_every_n_steps is not None
                    and n_steps >= args.eval_every_n_steps
                    and n_steps % args.eval_every_n_steps == 0):
                run_eval(vqa, val_data_loader, criterion, l2_criterion, args,
                         epoch, scheduler, info_scheduler)
                compare_outputs(images, answers, questions, qlengths, vqa,
                                vocab, logging, args)

            # Forward.
            vqa.train()
            gen_optimizer.zero_grad()
            info_optimizer.zero_grad()
            image_features = vqa.encode_images(images)
            question_features = vqa.encode_questions(questions, qlengths)

            # Question generation.
            mus, logvars = vqa.encode_into_z(image_features, question_features)
            zs = vqa.reparameterize(mus, logvars)
            (outputs, _, _) = vqa.decode_answers(image_features,
                                                 zs,
                                                 answers=answers,
                                                 teacher_forcing_ratio=1.0)

            # Reorder the questions based on length.
            answers = torch.index_select(answers, 0, aindices)

            # Ignoring the start token.
            answers = answers[:, 1:]
            alengths = process_lengths(answers)
            alengths.sort(reverse=True)

            # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
            # (BATCH x MAX_LEN x VOCAB).
            outputs = [o.unsqueeze(1) for o in outputs]
            outputs = torch.cat(outputs, dim=1)
            outputs = torch.index_select(outputs, 0, aindices)

            # Calculate the generation loss.
            targets = pack_padded_sequence(answers, alengths,
                                           batch_first=True)[0]
            outputs = pack_padded_sequence(outputs, alengths,
                                           batch_first=True)[0]
            gen_loss = criterion(outputs, targets)
            total_loss = 0.0
            total_loss += args.lambda_gen * gen_loss
            gen_loss = gen_loss.item()

            # Variational loss.
            kl_loss = gaussian_KL_loss(mus, logvars)
            total_loss += args.lambda_z * kl_loss
            kl_loss = kl_loss.item()

            # Generator Backprop.
            total_loss.backward()
            gen_optimizer.step()

            # Reconstruction loss.
            recon_image_loss = 0.0
            recon_question_loss = 0.0
            if not args.no_question_recon or not args.no_image_recon:
                total_info_loss = 0.0
                gen_optimizer.zero_grad()
                info_optimizer.zero_grad()
                question_targets = question_features.detach()
                image_targets = image_features.detach()

                recon_image_features, recon_question_features = vqa.reconstruct_inputs(
                    image_targets, question_targets)

                # Question reconstruction loss.
                if not args.no_question_recon:
                    recon_q_loss = l2_criterion(recon_question_features,
                                                question_targets)
                    total_info_loss += args.lambda_a * recon_q_loss
                    recon_question_loss = recon_q_loss.item()

                # Image reconstruction loss.
                if not args.no_image_recon:
                    recon_i_loss = l2_criterion(recon_image_features,
                                                image_targets)
                    total_info_loss += args.lambda_i * recon_i_loss
                    recon_image_loss = recon_i_loss.item()

                # Info backprop.
                total_info_loss.backward()
                info_optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info(
                    'Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                    'LR: %f, gen: %.4f, KL: %.4f, '
                    'I-recon: %.4f, Q-recon: %.4f' %
                    (delta_time, epoch, args.num_epochs, i, total_steps,
                     gen_optimizer.param_groups[0]['lr'], gen_loss, kl_loss,
                     recon_image_loss, recon_question_loss))

            # Save the models
            if args.save_step is not None and (i + 1) % args.save_step == 0:
                torch.save(
                    vqa.state_dict(),
                    os.path.join(args.model_path,
                                 'vqa-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        torch.save(
            vqa.state_dict(),
            os.path.join(args.model_path, 'vqa-tf-%d.pkl' % (epoch + 1)))

        # Evaluation and learning rate updates.
        run_eval(vqa, val_data_loader, criterion, l2_criterion, args, epoch,
                 scheduler, info_scheduler)
Example #8
def evaluate(vqg, data_loader, criterion, l2_criterion, args):
    vqg.eval()

    if (args.bayes):
        alpha = vqg.alpha

    total_gen_loss = 0.0
    total_kl = 0.0
    total_recon_image_loss = 0.0
    total_recon_category_loss = 0.0
    total_z_t_kl = 0.0
    total_t_kl_loss = 0.0
    regularisation_loss = 0.0
    c_loss = 0.0
    category_cycle_loss = 0.0

    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    start_time = time.time()
    for iterations, (images, questions, answers, categories,
                     qindices) in enumerate(data_loader):
        # TODO: remove answers from the data loader later.

        # Set mini-batch dataset
        if torch.cuda.is_available():
            images = images.cuda()
            questions = questions.cuda()
            answers = answers.cuda()
            categories = categories.cuda()
            qindices = qindices.cuda()
            if (args.bayes):
                alpha = vqg.alpha.cuda()

        # Forward, Backward and Optimize
        image_features = vqg.encode_images(images)
        category_features = vqg.encode_categories(categories)
        t_mus, t_logvars, ts = vqg.encode_into_t(image_features,
                                                 category_features)
        (outputs, _,
         other), pred_ques = vqg.decode_questions(image_features,
                                                  ts,
                                                  questions=questions,
                                                  teacher_forcing_ratio=1.0)

        # Reorder the questions based on length.
        questions = torch.index_select(questions, 0, qindices)
        total_loss = 0.0

        # Ignoring the start token.
        questions = questions[:, 1:]
        qlengths = process_lengths(questions)

        # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
        # (BATCH x MAX_LEN x VOCAB).
        outputs = [o.unsqueeze(1) for o in outputs]
        outputs = torch.cat(outputs, dim=1)
        outputs = torch.index_select(outputs, 0, qindices)

        if (args.step_two):
            category_cycle = vqg.encode_questions(pred_ques, qlengths)
            category_cycle_loss = criterion(category_cycle, categories)

            category_cycle_loss = category_cycle_loss.item()
            total_loss += args.lambda_c_cycle * category_cycle_loss

        # Calculate the loss.
        targets = pack_padded_sequence(questions, qlengths,
                                       batch_first=True)[0]
        outputs = pack_padded_sequence(outputs, qlengths, batch_first=True)[0]

        gen_loss = criterion(outputs, targets)
        total_gen_loss += gen_loss.item()

        # Get KL loss if it exists.
        if (args.bayes):
            regularisation_loss = l2_criterion(alpha.pow(-1),
                                               torch.ones_like(alpha))
            kl_loss = -0.5 * torch.sum(1 + t_logvars + alpha.pow(2).log() -
                                       alpha.pow(2) *
                                       (t_mus.pow(2) + t_logvars.exp()))
            total_kl += kl_loss.item() + regularisation_loss.item()
        else:
            kl_loss = gaussian_KL_loss(t_mus, t_logvars)
            kl_loss = kl_loss.item()
            total_kl += args.lambda_t * kl_loss

        # Reconstruction.
        if not args.no_image_recon or not args.no_category_space:
            image_targets = image_features.detach()
            category_targets = category_features.detach()

            recon_image_features, recon_category_features = vqg.reconstruct_inputs(
                image_targets, category_targets)

            if not args.no_image_recon:
                recon_i_loss = l2_criterion(recon_image_features,
                                            image_targets)
                total_recon_image_loss += recon_i_loss.item()

            if not args.no_category_space:
                recon_c_loss = l2_criterion(recon_category_features,
                                            category_targets)
                total_recon_category_loss += recon_c_loss.item()

        # Quit after eval_steps.
        if args.eval_steps is not None and iterations >= args.eval_steps:
            break

        # Print logs
        if iterations % args.log_step == 0:
            delta_time = time.time() - start_time
            start_time = time.time()
            logging.info(
                'Time: %.4f, Step [%d/%d], gen loss: %.4f, '
                'KL: %.4f, I-recon: %.4f, C-recon: %.4f, C-cycle: %.4f, Regularisation: %.4f'
                % (delta_time, iterations, total_steps, total_gen_loss /
                   (iterations + 1), total_kl /
                   (iterations + 1), total_recon_image_loss /
                   (iterations + 1), total_recon_category_loss /
                   (iterations + 1), category_cycle_loss /
                   (iterations + 1), regularisation_loss / (iterations + 1)))

    total_info_loss = total_recon_image_loss + total_recon_category_loss
    return total_gen_loss / (iterations + 1), total_info_loss / (iterations +
                                                                 1)
Example #9
def train(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Save the arguments.
    with open(os.path.join(args.model_path, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    log_file_name = 'train_' + args.train_log_file_suffix + '.log'
    logfile = os.path.join(args.model_path, log_file_name)
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(args.crop_size,
                                     scale=(1.00, 1.2),
                                     ratio=(0.75, 1.3333333333333333)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)

    # Load the category types.
    cat2name = json.load(open(args.cat2name))

    # Build data loader
    logging.info("Building data loader...")
    train_sampler = None
    val_sampler = None
    if os.path.exists(args.train_dataset_weights):
        train_weights = json.load(open(args.train_dataset_weights))
        train_weights = torch.DoubleTensor(train_weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))

    if os.path.exists(args.val_dataset_weights):
        val_weights = json.load(open(args.val_dataset_weights))
        val_weights = torch.DoubleTensor(val_weights)
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            val_weights, len(val_weights))

    data_loader = get_loader(args.dataset,
                             transform,
                             args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             max_examples=args.max_examples,
                             sampler=train_sampler)
    val_data_loader = get_loader(args.val_dataset,
                                 transform,
                                 args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 max_examples=args.max_examples,
                                 sampler=val_sampler)

    print('Done loading data ............................')

    logging.info("Done")

    vqg = create_model(args, vocab)
    if args.load_model is not None:
        vqg.load_state_dict(torch.load(args.load_model))

    logging.info("Done")

    # Loss criterion.
    pad = vocab(vocab.SYM_PAD)  # Index of the 'pad' symbol.
    criterion = nn.CrossEntropyLoss()
    criterion2 = nn.MultiMarginLoss()
    l2_criterion = nn.MSELoss()

    alpha = None

    if (args.bayes):
        alpha = vqg.alpha

    # Setup GPUs.
    if torch.cuda.is_available():
        logging.info("Using available GPU...")
        vqg.cuda()
        criterion.cuda()
        l2_criterion.cuda()
        if (alpha is not None):
            alpha = alpha.cuda()

    gen_params = vqg.generator_parameters()
    info_params = vqg.info_parameters()

    learning_rate = args.learning_rate
    info_learning_rate = args.info_learning_rate

    gen_optimizer = torch.optim.Adam(gen_params, lr=learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)

    if (args.step_two):
        cycle_params = vqg.cycle_params()
        cycle_optimizer = torch.optim.Adam(cycle_params, lr=learning_rate)

    if (args.center_loss):
        center_loss = CenterLoss(num_classes=args.num_categories,
                                 feat_dim=args.z_size,
                                 use_gpu=True)
        optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)

    scheduler = ReduceLROnPlateau(optimizer=gen_optimizer,
                                  mode='min',
                                  factor=0.5,
                                  patience=args.patience,
                                  verbose=True,
                                  min_lr=1e-7)
    cycle_scheduler = ReduceLROnPlateau(optimizer=gen_optimizer,
                                        mode='min',
                                        factor=0.99,
                                        patience=args.patience,
                                        verbose=True,
                                        min_lr=1e-7)
    info_scheduler = ReduceLROnPlateau(optimizer=info_optimizer,
                                       mode='min',
                                       factor=0.5,
                                       patience=args.patience,
                                       verbose=True,
                                       min_lr=1e-7)

    # Train the model.
    total_steps = len(data_loader)
    start_time = time.time()
    n_steps = 0

    # Optional losses. Initialized here for logging.
    recon_category_loss = 0.0
    recon_image_loss = 0.0
    kl_loss = 0.0
    category_cycle_loss = 0.0
    regularisation_loss = 0.0
    c_loss = 0.0
    cycle_loss = 0.0

    if (args.step_two):
        category_cycle_loss = 0.0

    if (args.bayes):
        regularisation_loss = 0.0

    if (args.center_loss):
        loss_center = 0.0
        c_loss = 0.0

    for epoch in range(args.num_epochs):
        for i, (images, questions, answers, categories,
                qindices) in enumerate(data_loader):
            n_steps += 1
            # TODO: remove answers from the data loader later.

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()
                categories = categories.cuda()
                qindices = qindices.cuda()
                if (args.bayes):
                    alpha = alpha.cuda()

            # Eval now.
            if (args.eval_every_n_steps is not None
                    and n_steps >= args.eval_every_n_steps
                    and n_steps % args.eval_every_n_steps == 0):
                run_eval(vqg, val_data_loader, criterion, l2_criterion, args,
                         epoch, scheduler, info_scheduler)
                compare_outputs(images, questions, answers, categories, vqg,
                                vocab, logging, cat2name, args)

            # Forward.
            vqg.train()
            gen_optimizer.zero_grad()
            info_optimizer.zero_grad()
            if (args.step_two):
                cycle_optimizer.zero_grad()
            if (args.center_loss):
                optimizer_centloss.zero_grad()

            image_features = vqg.encode_images(images)
            category_features = vqg.encode_categories(categories)

            # Question generation.
            t_mus, t_logvars, ts = vqg.encode_into_t(image_features,
                                                     category_features)

            if (args.center_loss):
                loss_center = 0.0
                c_loss = center_loss(ts, categories)
                loss_center += args.lambda_centerloss * c_loss
                c_loss = c_loss.item()
                loss_center.backward(retain_graph=True)

                for param in center_loss.parameters():
                    param.grad.data *= (1. / args.lambda_centerloss)

                optimizer_centloss.step()

            qlengths_prev = process_lengths(questions)

            (outputs, _,
             _), pred_ques = vqg.decode_questions(image_features,
                                                  ts,
                                                  questions=questions,
                                                  teacher_forcing_ratio=1.0)

            # Reorder the questions based on length.
            questions = torch.index_select(questions, 0, qindices)

            # Ignoring the start token.
            questions = questions[:, 1:]
            qlengths = process_lengths(questions)

            # Convert the output from MAX_LEN list of (BATCH x VOCAB) ->
            # (BATCH x MAX_LEN x VOCAB).

            outputs = [o.unsqueeze(1) for o in outputs]
            outputs = torch.cat(outputs, dim=1)

            outputs = torch.index_select(outputs, 0, qindices)

            if (args.step_two):
                category_cycle_loss = 0.0
                category_cycle = vqg.encode_questions(pred_ques, qlengths)
                cycle_loss = criterion(category_cycle, categories)
                category_cycle_loss += args.lambda_c_cycle * cycle_loss
                cycle_loss = cycle_loss.item()
                category_cycle_loss.backward(retain_graph=True)
                cycle_optimizer.step()

            # Calculate the generation loss.
            targets = pack_padded_sequence(questions,
                                           qlengths,
                                           batch_first=True)[0]
            outputs = pack_padded_sequence(outputs, qlengths,
                                           batch_first=True)[0]

            gen_loss = criterion(outputs, targets)
            total_loss = 0.0
            total_loss += args.lambda_gen * gen_loss
            gen_loss = gen_loss.item()

            # Variational loss.
            if (args.bayes):
                kl_loss = -0.5 * torch.sum(1 + t_logvars + alpha.pow(2).log() -
                                           alpha.pow(2) *
                                           (t_mus.pow(2) + t_logvars.exp()))
                regularisation_loss = l2_criterion(alpha.pow(-1),
                                                   torch.ones_like(alpha))
                total_loss += args.lambda_t * kl_loss + args.lambda_reg * regularisation_loss
                kl_loss = kl_loss.item()
                regularisation_loss = regularisation_loss.item()
            else:
                kl_loss = gaussian_KL_loss(t_mus, t_logvars)
                total_loss += args.lambda_t * kl_loss
                kl_loss = kl_loss.item()

            # Generator Backprop.
            total_loss.backward(retain_graph=True)
            gen_optimizer.step()

            # Reconstruction loss.
            recon_image_loss = 0.0
            recon_category_loss = 0.0

            if not args.no_category_space or not args.no_image_recon:
                total_info_loss = 0.0
                category_targets = category_features.detach()
                image_targets = image_features.detach()
                recon_image_features, recon_category_features = vqg.reconstruct_inputs(
                    image_targets, category_targets)

                # Category reconstruction loss.
                if not args.no_category_space:
                    recon_c_loss = l2_criterion(recon_category_features,
                                                category_targets)
                    total_info_loss += args.lambda_c * recon_c_loss
                    recon_category_loss = recon_c_loss.item()

                # Image reconstruction loss.
                if not args.no_image_recon:
                    recon_i_loss = l2_criterion(recon_image_features,
                                                image_targets)
                    total_info_loss += args.lambda_i * recon_i_loss
                    recon_image_loss = recon_i_loss.item()

                # Info backprop.
                total_info_loss.backward()
                info_optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info(
                    'Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                    'LR: %f, Center-Loss: %.4f, KL: %.4f, '
                    'I-recon: %.4f, C-recon: %.4f, C-cycle: %.4f, Regularisation: %.4f'
                    % (delta_time, epoch, args.num_epochs, i, total_steps,
                       gen_optimizer.param_groups[0]['lr'], c_loss, kl_loss,
                       recon_image_loss, recon_category_loss, cycle_loss,
                       regularisation_loss))

            # Save the models
            if args.save_step is not None and (i + 1) % args.save_step == 0:
                torch.save(
                    vqg.state_dict(),
                    os.path.join(args.model_path,
                                 'vqg-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        torch.save(
            vqg.state_dict(),
            os.path.join(args.model_path, 'vqg-tf-%d.pkl' % (epoch + 1)))

        if args.center_loss:
            torch.save(
                center_loss.state_dict(),
                os.path.join(args.model_path,
                             'closs-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        # Evaluation and learning rate updates.
        run_eval(vqg, val_data_loader, criterion, l2_criterion, args, epoch,
                 scheduler, info_scheduler)
Example #10
    def forward(self, images, answers, alengths=None, questions=None):
        """Passes the image and the question through a VQA model and generates answers.

        Args:
            images: Batch of image Variables.
            questions: Batch of question Variables.
            qlengths: List of question lengths.
            answers: Batch of answer Variables.

        Returns:
            - outputs: The output scores for all steps in the RNN.
            - hidden: The hidden states of all the RNNs.
            - ret_dict: A dictionary of attributes. See DecoderRNN.py for details.
        """

        # features is (N * 2048 * 56 * 56)

        input_spatial_dim = images.size()[2:]
        features = self.encoder_cnn.resnet(images)

        # encoder_hidden is ((BIDIRECTIONAL x NUM_LAYERS) * N * HIDDEN_SIZE).
        _, encoder_hidden_ans = self.encoder_rnn(answers, alengths, None)

        if self.encoder_rnn.rnn_cell is nn.LSTM:
            encoder_hidden_ans = encoder_hidden_ans[0]
        encoder_hidden_ans = encoder_hidden_ans.transpose(0, 1).contiguous()

        if self.bidirectional_multiplier == 2:
            encoder_hidden = torch.cat(
                (encoder_hidden_ans[:, 0], encoder_hidden_ans[:, -1]), dim=1)
        else:
            encoder_hidden = encoder_hidden_ans[:, -1]

        if questions is not None:
            alengths = process_lengths(questions)
            # Reorder based on length
            sort_index = sorted(range(len(alengths)),
                                key=lambda x: alengths[x].item(),
                                reverse=True)
            questions = questions[sort_index]
            alengths = np.array(alengths)[sort_index].tolist()
            _, encoder_hidden_qs = self.encoder_rnn(questions, alengths, None)
            if self.encoder_rnn.rnn_cell is nn.LSTM:
                encoder_hidden_qs = encoder_hidden_qs[0]
                encoder_hidden_qs = encoder_hidden_qs.transpose(
                    0, 1).contiguous()

            if self.bidirectional_multiplier == 2:
                encoder_hidden_qs = torch.cat(
                    (encoder_hidden_qs[:, 0], encoder_hidden_qs[:, -1]), dim=1)
            else:
                encoder_hidden_qs = encoder_hidden_qs[:, -1]

            # Reorder to match answer ordering
            ordering = [sort_index.index(i) for i in range(images.size(0))]
            encoder_hidden_qs = encoder_hidden_qs[ordering]
            encoder_hidden = torch.cat([encoder_hidden, encoder_hidden_qs],
                                       dim=1)

        # Pass the features through the stacked attention network.
        encoder_hidden = encoder_hidden.unsqueeze(2).unsqueeze(2).repeat(
            1, 1, features.size(2), features.size(3))
        features = self.encoder_cnn.fc(features * encoder_hidden)
        result = nn.functional.upsample_bilinear(input=features,
                                                 size=input_spatial_dim)

        return result
Example #11
def main(args):

    # Setting up seeds.
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)

    # Create model directory.
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    logfile = os.path.join(args.model_dir, 'train.log')
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Save the arguments.
    with open(os.path.join(args.model_dir, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)
    vocab.top_answers = json.load(open(args.top_answers))

    # Build data loader.
    logging.info("Building data loader...")
    data_loader = get_vqa_loader(args.dataset,
                                 args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 max_examples=args.max_examples)

    val_data_loader = get_vqa_loader(args.val_dataset,
                                     args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)
    logging.info("Done")

    # Build the models
    logging.info("Building MultiSAVQA models...")
    vqa = MultiSAVQAModel(len(vocab),
                          args.max_length,
                          args.hidden_size,
                          args.vocab_embed_size,
                          num_layers=args.num_layers,
                          rnn_cell=args.rnn_cell,
                          bidirectional=args.bidirectional,
                          input_dropout_p=args.dropout,
                          dropout_p=args.dropout,
                          num_att_layers=args.num_att_layers,
                          att_ff_size=args.att_ff_size)
    logging.info("Done")

    if torch.cuda.is_available():
        vqa.cuda()

    # Loss and Optimizer.
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion.cuda()

    # Parameters to train.
    params = vqa.params_to_train()
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  mode='min',
                                  factor=0.1,
                                  patience=args.patience,
                                  verbose=True,
                                  min_lr=1e-6)

    # Train the Models.
    total_steps = len(data_loader) * args.num_epochs
    start_time = time.time()
    n_steps = 0
    for epoch in range(args.num_epochs):
        for i, (feats, questions, categories) in enumerate(data_loader):
            n_steps += 1

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                feats = feats.cuda()
                questions = questions.cuda()
                categories = categories.cuda()
            qlengths = process_lengths(questions)

            # Forward.
            vqa.train()
            vqa.zero_grad()
            outputs = vqa(feats, questions, qlengths)

            # Calculate the loss.
            loss = criterion(outputs, categories)

            # Backprop and optimize.
            loss.backward()
            optimizer.step()

            # Eval now.
            if (args.eval_every_n_steps is not None
                    and n_steps >= args.eval_every_n_steps
                    and n_steps % args.eval_every_n_steps == 0):
                logging.info('=' * 100)
                val_loss = evaluate(vocab, vqa, val_data_loader, criterion,
                                    epoch, args)
                scheduler.step(val_loss)
                logging.info('=' * 100)

            # Take argmax for each timestep
            preds = outputs.max(1)[1]
            score = accuracy(preds, categories)

            # Print log info.
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info('Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                             'Accuracy: %.4f, Loss: %.4f, LR: %f' %
                             (delta_time, epoch + 1,
                              args.num_epochs, n_steps, total_steps, score,
                              loss.item(), optimizer.param_groups[0]['lr']))

            # Save the models.
            if (i + 1) % args.save_step == 0:
                torch.save(
                    vqa.state_dict(),
                    os.path.join(args.model_dir,
                                 'multi-savqa-%d-%d.pkl' % (epoch + 1, i + 1)))

        torch.save(
            vqa.state_dict(),
            os.path.join(args.model_dir, 'multi-savqa-%d.pkl' % (epoch + 1)))

        # Evaluation and learning rate updates.
        logging.info('=' * 100)
        val_loss = evaluate(vocab, vqa, val_data_loader, criterion, epoch,
                            args)
        scheduler.step(val_loss)
        logging.info('=' * 100)

    # Save the final model.
    torch.save(vqa.state_dict(), os.path.join(args.model_dir, 'vqa.pkl'))
Example #12
def train(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Save the arguments.
    with open(os.path.join(args.model_path, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    logfile = os.path.join(args.model_path, 'train.log')
    logging.basicConfig(filename=logfile,
                        level=logging.INFO,
                        format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    vocab = load_vocab(args.vocab_path)

    # Build data loader
    logging.info("Building data loader...")
    train_sampler = None
    val_sampler = None

    if os.path.exists(args.train_dataset_weights):
        train_weights = json.load(open(args.train_dataset_weights))
        train_weights = torch.DoubleTensor(train_weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))
    if os.path.exists(args.val_dataset_weights):
        val_weights = json.load(open(args.val_dataset_weights))
        val_weights = torch.DoubleTensor(val_weights)
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            val_weights, len(val_weights))
    data_loader = get_loader(args.dataset,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             max_examples=args.max_examples,
                             sampler=train_sampler)
    val_data_loader = get_loader(args.val_dataset,
                                 args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 max_examples=args.max_examples,
                                 sampler=val_sampler)
    logging.info("Done")

    model = create_model(args, vocab)

    if args.load_model is not None:
        model.load_state_dict(torch.load(args.load_model))
    logging.info("Done")

    criterion = nn.CrossEntropyLoss()

    #criterion = nn.BCELoss()
    # Setup GPUs.
    if torch.cuda.is_available():
        logging.info("Using available GPU...")
        model.cuda()
        criterion.cuda()
        torch.backends.cudnn.enabled = True
        cudnn.benchmark = True
    # Parameters to train.
    params = model.parameters()

    learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(params, lr=learning_rate)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  mode='min',
                                  factor=0.1,
                                  patience=args.patience,
                                  verbose=True,
                                  min_lr=1e-7)

    # Train the model.
    total_steps = len(data_loader)
    start_time = time.time()
    n_steps = 0

    # Per-epoch loss and accuracy, re-initialized at the start of every epoch.
    epoch_loss = 0
    epoch_acc = 0

    for epoch in range(args.num_epochs):
        epoch_loss = 0
        epoch_acc = 0
        model.train()
        for i, (sentences, labels, qindices) in enumerate(data_loader):
            n_steps += 1

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                sentences = sentences.cuda()
                labels = labels.cuda()
                qindices = qindices.cuda()
            lengths = process_lengths(sentences)
            lengths.sort(reverse=True)
            # Reset the gradients after every batch.
            optimizer.zero_grad()
            # Forward pass.
            predictions = model(sentences, lengths)

            loss = criterion(predictions, labels)
            # Compute the binary accuracy.
            acc = binary_accuracy(predictions, labels)
            # Backpropagate the loss and compute the gradients.
            loss.backward()

            # Update the weights.
            optimizer.step()
            # Accumulate loss and accuracy.
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        delta_time = time.time() - start_time
        start_time = time.time()
        logging.info('Epoch [%d/%d] | Step [%d/%d] | Time: %.4f  \n'
                     '\t Train-Loss: %.4f | Train-Acc: %.2f' %
                     (epoch, args.num_epochs, i, total_steps, delta_time,
                      epoch_loss / len(data_loader),
                      epoch_acc / len(data_loader) * 100))
        evaluate(model, val_data_loader, criterion, args)
        torch.save(
            model.state_dict(),
            os.path.join(args.model_path, 'model-tf-%d.pkl' % (epoch + 1)))