def evaluate(vqa, data_loader, criterion, l2_criterion, args):
    """Calculates the average vqa losses on data_loader.

    Args:
        vqa: visual question answering model.
        data_loader: Iterator for the data.
        criterion: The loss function used to evaluate the generation loss.
        l2_criterion: The loss function used to evaluate the l2 (reconstruction) loss.
        args: ArgumentParser object.

    Returns:
        A tuple of (average generation loss, average reconstruction loss).
    """
    vqa.eval()
    total_gen_loss = 0.0
    total_kl = 0.0
    total_recon_image_loss = 0.0
    total_recon_question_loss = 0.0
    total_z_t_kl = 0.0
    total_t_kl_loss = 0.0
    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    start_time = time.time()
    for iterations, (images, questions, answers, aindices) in enumerate(data_loader):

        # Set mini-batch dataset.
        if torch.cuda.is_available():
            images = images.cuda()
            questions = questions.cuda()
            answers = answers.cuda()
            aindices = aindices.cuda()
        qlengths = process_lengths(questions)
        qlengths.sort(reverse=True)

        # Forward pass.
        image_features = vqa.encode_images(images)
        question_features = vqa.encode_questions(questions, qlengths)
        mus, logvars = vqa.encode_into_z(image_features, question_features)
        zs = vqa.reparameterize(mus, logvars)
        (outputs, _, other) = vqa.decode_answers(image_features, zs,
                                                 answers=answers,
                                                 teacher_forcing_ratio=1.0)

        # Reorder the answers based on length.
        answers = torch.index_select(answers, 0, aindices)

        # Ignore the start token.
        answers = answers[:, 1:]
        alengths = process_lengths(answers)
        alengths.sort(reverse=True)

        # Convert the output from a MAX_LEN list of (BATCH x VOCAB) ->
        # (BATCH x MAX_LEN x VOCAB).
        outputs = [o.unsqueeze(1) for o in outputs]
        outputs = torch.cat(outputs, dim=1)
        outputs = torch.index_select(outputs, 0, aindices)

        # Calculate the generation loss.
        targets = pack_padded_sequence(answers, alengths, batch_first=True)[0]
        outputs = pack_padded_sequence(outputs, alengths, batch_first=True)[0]
        gen_loss = criterion(outputs, targets)
        total_gen_loss += gen_loss.data.item()

        # KL loss.
        kl_loss = gaussian_KL_loss(mus, logvars)
        total_kl += kl_loss.item()

        # Reconstruction losses.
        if not args.no_image_recon or not args.no_question_recon:
            image_targets = image_features.detach()
            question_targets = question_features.detach()
            recon_image_features, recon_question_features = vqa.reconstruct_inputs(
                image_targets, question_targets)
            if not args.no_image_recon:
                recon_i_loss = l2_criterion(recon_image_features, image_targets)
                total_recon_image_loss += recon_i_loss.item()
            if not args.no_question_recon:
                recon_q_loss = l2_criterion(recon_question_features,
                                            question_targets)
                total_recon_question_loss += recon_q_loss.item()

        # Quit after eval_steps.
        if args.eval_steps is not None and iterations >= args.eval_steps:
            break

        # Print logs.
        if iterations % args.log_step == 0:
            delta_time = time.time() - start_time
            start_time = time.time()
            logging.info('Time: %.4f, Step [%d/%d], gen loss: %.4f, '
                         'KL: %.4f, I-recon: %.4f, Q-recon: %.4f'
                         % (delta_time, iterations, total_steps,
                            total_gen_loss / (iterations + 1),
                            total_kl / (iterations + 1),
                            total_recon_image_loss / (iterations + 1),
                            total_recon_question_loss / (iterations + 1)))

    total_info_loss = total_recon_image_loss + total_recon_question_loss
    return (total_gen_loss / (iterations + 1),
            total_info_loss / (iterations + 1))
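# The helpers `process_lengths` and `gaussian_KL_loss` are called throughout this
# file but defined elsewhere in the repository. What follows is a minimal sketch
# of plausible implementations, not the authors' code: `process_lengths` assumes
# sequences are padded with token index 0, and `gaussian_KL_loss` is the standard
# KL divergence from N(mu, sigma^2) to a unit Gaussian; the exact reduction
# (sum vs. batch mean) is an assumption. Similarly, the model's
# `reparameterize(mus, logvars)` presumably implements the usual trick
# z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, 1).
import torch


def process_lengths(inputs, pad=0):
    """Lengths of each row of a (BATCH x MAX_LEN) padded index tensor."""
    max_length = inputs.size(1)
    return (max_length - inputs.eq(pad).sum(dim=1)).tolist()


def gaussian_KL_loss(mus, logvars):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over dims, averaged over the batch."""
    klds = -0.5 * (1 + logvars - mus.pow(2) - logvars.exp())
    return klds.sum(dim=1).mean()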
def evaluate(vqg, data_loader, criterion, l2_criterion, args):
    """Calculates the average vqg losses on data_loader.

    Args:
        vqg: visual question generation model.
        data_loader: Iterator for the data.
        criterion: The cross-entropy loss used for generation (and for the
            category cycle in step two).
        l2_criterion: The loss function used to evaluate the l2 (reconstruction) loss.
        args: ArgumentParser object.

    Returns:
        A tuple of (average generation loss, average reconstruction loss).
    """
    vqg.eval()
    if args.bayes:
        alpha = vqg.alpha
    total_gen_loss = 0.0
    total_kl = 0.0
    total_recon_image_loss = 0.0
    total_recon_category_loss = 0.0
    total_z_t_kl = 0.0
    total_t_kl_loss = 0.0
    regularisation_loss = 0.0
    c_loss = 0.0
    category_cycle_loss = 0.0
    total_steps = len(data_loader)
    if args.eval_steps is not None:
        total_steps = min(len(data_loader), args.eval_steps)
    start_time = time.time()
    for iterations, (images, questions, answers,
                     categories, qindices) in enumerate(data_loader):
        # TODO: remove answers from the data loader later.

        # Set mini-batch dataset.
        if torch.cuda.is_available():
            images = images.cuda()
            questions = questions.cuda()
            answers = answers.cuda()
            categories = categories.cuda()
            qindices = qindices.cuda()
            if args.bayes:
                alpha = vqg.alpha.cuda()

        # Forward pass.
        image_features = vqg.encode_images(images)
        category_features = vqg.encode_categories(categories)
        t_mus, t_logvars, ts = vqg.encode_into_t(image_features, category_features)
        (outputs, _, other), pred_ques = vqg.decode_questions(
            image_features, ts, questions=questions, teacher_forcing_ratio=1.0)

        # Reorder the questions based on length.
        questions = torch.index_select(questions, 0, qindices)
        total_loss = 0.0

        # Ignore the start token.
        questions = questions[:, 1:]
        qlengths = process_lengths(questions)

        # Convert the output from a MAX_LEN list of (BATCH x VOCAB) ->
        # (BATCH x MAX_LEN x VOCAB).
        outputs = [o.unsqueeze(1) for o in outputs]
        outputs = torch.cat(outputs, dim=1)
        outputs = torch.index_select(outputs, 0, qindices)

        # Category-cycle loss (step two): re-encode the predicted questions and
        # check that the category can be recovered.
        if args.step_two:
            category_cycle = vqg.encode_questions(pred_ques, qlengths)
            category_cycle_loss = criterion(category_cycle, categories)
            category_cycle_loss = category_cycle_loss.item()
            total_loss += args.lambda_c_cycle * category_cycle_loss

        # Calculate the generation loss.
        targets = pack_padded_sequence(questions, qlengths, batch_first=True)[0]
        outputs = pack_padded_sequence(outputs, qlengths, batch_first=True)[0]
        gen_loss = criterion(outputs, targets)
        total_gen_loss += gen_loss.data.item()

        # KL loss.
        if args.bayes:
            regularisation_loss = l2_criterion(alpha.pow(-1),
                                               torch.ones_like(alpha))
            kl_loss = -0.5 * torch.sum(
                1 + t_logvars + alpha.pow(2).log()
                - alpha.pow(2) * (t_mus.pow(2) + t_logvars.exp()))
            total_kl += kl_loss.item() + regularisation_loss.item()
        else:
            kl_loss = gaussian_KL_loss(t_mus, t_logvars)
            # Accumulate a Python float so the running total does not keep the
            # computation graph alive across iterations.
            total_kl += args.lambda_t * kl_loss.item()

        # Reconstruction losses.
        if not args.no_image_recon or not args.no_category_space:
            image_targets = image_features.detach()
            category_targets = category_features.detach()
            recon_image_features, recon_category_features = vqg.reconstruct_inputs(
                image_targets, category_targets)
            if not args.no_image_recon:
                recon_i_loss = l2_criterion(recon_image_features, image_targets)
                total_recon_image_loss += recon_i_loss.item()
            if not args.no_category_space:
                recon_c_loss = l2_criterion(recon_category_features,
                                            category_targets)
                total_recon_category_loss += recon_c_loss.item()

        # Quit after eval_steps.
        if args.eval_steps is not None and iterations >= args.eval_steps:
            break

        # Print logs. Note that category_cycle_loss and regularisation_loss hold
        # the most recent step's values, not running totals.
        if iterations % args.log_step == 0:
            delta_time = time.time() - start_time
            start_time = time.time()
            logging.info(
                'Time: %.4f, Step [%d/%d], gen loss: %.4f, '
                'KL: %.4f, I-recon: %.4f, C-recon: %.4f, '
                'C-cycle: %.4f, Regularisation: %.4f'
                % (delta_time, iterations, total_steps,
                   total_gen_loss / (iterations + 1),
                   total_kl / (iterations + 1),
                   total_recon_image_loss / (iterations + 1),
                   total_recon_category_loss / (iterations + 1),
                   category_cycle_loss / (iterations + 1),
                   regularisation_loss / (iterations + 1)))

    total_info_loss = total_recon_image_loss + total_recon_category_loss
    return (total_gen_loss / (iterations + 1),
            total_info_loss / (iterations + 1))
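# `run_eval` is invoked from both training loops below but is not defined in this
# file. The sketch below is an assumption about what it does, not the authors'
# code: it presumably wraps `evaluate` in no-grad mode and feeds the validation
# losses to the two ReduceLROnPlateau schedulers. The logging format and the
# choice to step info_scheduler on the reconstruction loss are guesses.
def run_eval(model, val_data_loader, criterion, l2_criterion, args, epoch,
             scheduler, info_scheduler):
    """Hedged sketch: evaluate on the validation set and update the LR schedulers."""
    logging.info('=' * 80)
    start_time = time.time()
    with torch.no_grad():
        val_gen_loss, val_info_loss = evaluate(
            model, val_data_loader, criterion, l2_criterion, args)
    delta_time = time.time() - start_time
    logging.info('Time: %.4f, Epoch [%d/%d], Val-gen-loss: %.4f, Val-info-loss: %.4f'
                 % (delta_time, epoch, args.num_epochs, val_gen_loss, val_info_loss))
    logging.info('=' * 80)
    # Reduce the learning rates when the validation losses plateau.
    scheduler.step(val_gen_loss)
    info_scheduler.step(val_info_loss)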
def train(args):
    # Create the model directory.
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Save the arguments.
    with open(os.path.join(args.model_path, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    logfile = os.path.join(args.model_path, 'train.log')
    logging.basicConfig(filename=logfile, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Image preprocessing.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(args.crop_size,
                                     scale=(1.00, 1.2),
                                     ratio=(0.75, 1.3333333333333333)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)

    # Build the data loaders.
    logging.info("Building data loader...")
    train_sampler = None
    val_sampler = None
    if os.path.exists(args.train_dataset_weights):
        train_weights = json.load(open(args.train_dataset_weights))
        train_weights = torch.DoubleTensor(train_weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))
    if os.path.exists(args.val_dataset_weights):
        val_weights = json.load(open(args.val_dataset_weights))
        val_weights = torch.DoubleTensor(val_weights)
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            val_weights, len(val_weights))
    data_loader = get_loader(args.dataset, transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers,
                             max_examples=args.max_examples,
                             sampler=train_sampler)
    val_data_loader = get_loader(args.val_dataset, transform, args.batch_size,
                                 shuffle=False, num_workers=args.num_workers,
                                 max_examples=args.max_examples,
                                 sampler=val_sampler)
    logging.info("Done")

    vqa = create_model(args, vocab)
    if args.load_model is not None:
        vqa.load_state_dict(torch.load(args.load_model))
    logging.info("Done")

    # Loss criteria.
    pad = vocab(vocab.SYM_PAD)  # Set the loss weight for 'pad' tokens to 0.
    criterion = nn.CrossEntropyLoss(ignore_index=pad)
    l2_criterion = nn.MSELoss()

    # Setup GPUs.
    if torch.cuda.is_available():
        logging.info("Using available GPU...")
        vqa.cuda()
        criterion.cuda()
        l2_criterion.cuda()

    # Parameters to train.
    gen_params = vqa.generator_parameters()
    info_params = vqa.info_parameters()
    learning_rate = args.learning_rate
    info_learning_rate = args.info_learning_rate
    gen_optimizer = torch.optim.Adam(gen_params, lr=learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)
    scheduler = ReduceLROnPlateau(optimizer=gen_optimizer, mode='min',
                                  factor=0.1, patience=args.patience,
                                  verbose=True, min_lr=1e-7)
    info_scheduler = ReduceLROnPlateau(optimizer=info_optimizer, mode='min',
                                       factor=0.1, patience=args.patience,
                                       verbose=True, min_lr=1e-7)

    # Train the model.
    total_steps = len(data_loader)
    start_time = time.time()
    n_steps = 0

    # Optional losses. Initialized here for logging.
    recon_question_loss = 0.0
    recon_image_loss = 0.0
    kl_loss = 0.0
    z_t_kl = 0.0
    t_kl = 0.0

    for epoch in range(args.num_epochs):
        for i, (images, questions, answers, aindices) in enumerate(data_loader):
            n_steps += 1

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()
                aindices = aindices.cuda()
            qlengths = process_lengths(questions)
            qlengths.sort(reverse=True)

            # Eval now.
            if (args.eval_every_n_steps is not None and
                    n_steps >= args.eval_every_n_steps and
                    n_steps % args.eval_every_n_steps == 0):
                run_eval(vqa, val_data_loader, criterion, l2_criterion,
                         args, epoch, scheduler, info_scheduler)
                compare_outputs(images, answers, questions, qlengths,
                                vqa, vocab, logging, args)

            # Forward.
            vqa.train()
            gen_optimizer.zero_grad()
            info_optimizer.zero_grad()
            image_features = vqa.encode_images(images)
            question_features = vqa.encode_questions(questions, qlengths)

            # Answer generation.
            mus, logvars = vqa.encode_into_z(image_features, question_features)
            zs = vqa.reparameterize(mus, logvars)
            (outputs, _, _) = vqa.decode_answers(image_features, zs,
                                                 answers=answers,
                                                 teacher_forcing_ratio=1.0)

            # Reorder the answers based on length.
            answers = torch.index_select(answers, 0, aindices)

            # Ignore the start token.
            answers = answers[:, 1:]
            alengths = process_lengths(answers)
            alengths.sort(reverse=True)

            # Convert the output from a MAX_LEN list of (BATCH x VOCAB) ->
            # (BATCH x MAX_LEN x VOCAB).
            outputs = [o.unsqueeze(1) for o in outputs]
            outputs = torch.cat(outputs, dim=1)
            outputs = torch.index_select(outputs, 0, aindices)

            # Calculate the generation loss.
            targets = pack_padded_sequence(answers, alengths,
                                           batch_first=True)[0]
            outputs = pack_padded_sequence(outputs, alengths,
                                           batch_first=True)[0]
            gen_loss = criterion(outputs, targets)
            total_loss = 0.0
            total_loss += args.lambda_gen * gen_loss
            gen_loss = gen_loss.item()

            # Variational loss.
            kl_loss = gaussian_KL_loss(mus, logvars)
            total_loss += args.lambda_z * kl_loss
            kl_loss = kl_loss.item()

            # Generator backprop.
            total_loss.backward()
            gen_optimizer.step()

            # Reconstruction losses.
            recon_image_loss = 0.0
            recon_question_loss = 0.0
            if not args.no_question_recon or not args.no_image_recon:
                total_info_loss = 0.0
                gen_optimizer.zero_grad()
                info_optimizer.zero_grad()
                question_targets = question_features.detach()
                image_targets = image_features.detach()
                recon_image_features, recon_question_features = vqa.reconstruct_inputs(
                    image_targets, question_targets)

                # Question reconstruction loss.
                if not args.no_question_recon:
                    recon_q_loss = l2_criterion(recon_question_features,
                                                question_targets)
                    total_info_loss += args.lambda_a * recon_q_loss
                    recon_question_loss = recon_q_loss.item()

                # Image reconstruction loss.
                if not args.no_image_recon:
                    recon_i_loss = l2_criterion(recon_image_features,
                                                image_targets)
                    total_info_loss += args.lambda_i * recon_i_loss
                    recon_image_loss = recon_i_loss.item()

                # Info backprop.
                total_info_loss.backward()
                info_optimizer.step()

            # Print log info.
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info('Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                             'LR: %f, gen: %.4f, KL: %.4f, '
                             'I-recon: %.4f, Q-recon: %.4f'
                             % (delta_time, epoch, args.num_epochs, i,
                                total_steps,
                                gen_optimizer.param_groups[0]['lr'],
                                gen_loss, kl_loss,
                                recon_image_loss, recon_question_loss))

            # Save the step checkpoints.
            if args.save_step is not None and (i + 1) % args.save_step == 0:
                torch.save(vqa.state_dict(),
                           os.path.join(args.model_path,
                                        'vqa-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        # Save the epoch checkpoint.
        torch.save(vqa.state_dict(),
                   os.path.join(args.model_path, 'vqa-tf-%d.pkl' % (epoch + 1)))

        # Evaluation and learning rate updates.
        run_eval(vqa, val_data_loader, criterion, l2_criterion, args,
                 epoch, scheduler, info_scheduler)
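# Illustrative only: a minimal entry point for the VQA training script above,
# assuming it lives in its own module. The flag names mirror the args.* fields
# referenced in train()/evaluate(); the default values are placeholders, not the
# authors' settings, and some flags may be missing.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Paths.
    parser.add_argument('--model-path', type=str, default='weights/vqa/')
    parser.add_argument('--vocab-path', type=str, default='data/vocab.json')
    parser.add_argument('--dataset', type=str, default='data/train.hdf5')
    parser.add_argument('--val-dataset', type=str, default='data/val.hdf5')
    parser.add_argument('--train-dataset-weights', type=str, default='')
    parser.add_argument('--val-dataset-weights', type=str, default='')
    parser.add_argument('--load-model', type=str, default=None)
    # Optimization.
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--num-epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=1e-3)
    parser.add_argument('--info-learning-rate', type=float, default=1e-3)
    parser.add_argument('--patience', type=int, default=10)
    # Loss weights and ablations.
    parser.add_argument('--lambda-gen', type=float, default=1.0)
    parser.add_argument('--lambda-z', type=float, default=0.01)
    parser.add_argument('--lambda-a', type=float, default=1.0)
    parser.add_argument('--lambda-i', type=float, default=1.0)
    parser.add_argument('--no-image-recon', action='store_true')
    parser.add_argument('--no-question-recon', action='store_true')
    # Bookkeeping.
    parser.add_argument('--crop-size', type=int, default=224)
    parser.add_argument('--max-examples', type=int, default=None)
    parser.add_argument('--log-step', type=int, default=10)
    parser.add_argument('--save-step', type=int, default=None)
    parser.add_argument('--eval-steps', type=int, default=None)
    parser.add_argument('--eval-every-n-steps', type=int, default=None)
    train(parser.parse_args())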
def train(args):
    # Create the model directory.
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Save the arguments.
    with open(os.path.join(args.model_path, 'args.json'), 'w') as args_file:
        json.dump(args.__dict__, args_file)

    # Config logging.
    log_format = '%(levelname)-8s %(message)s'
    log_file_name = 'train_' + args.train_log_file_suffix + '.log'
    logfile = os.path.join(args.model_path, log_file_name)
    logging.basicConfig(filename=logfile, level=logging.INFO, format=log_format)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info(json.dumps(args.__dict__))

    # Image preprocessing.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(args.crop_size,
                                     scale=(1.00, 1.2),
                                     ratio=(0.75, 1.3333333333333333)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])

    # Load vocabulary wrapper.
    vocab = load_vocab(args.vocab_path)

    # Load the category types.
    cat2name = json.load(open(args.cat2name))

    # Build the data loaders.
    logging.info("Building data loader...")
    train_sampler = None
    val_sampler = None
    if os.path.exists(args.train_dataset_weights):
        train_weights = json.load(open(args.train_dataset_weights))
        train_weights = torch.DoubleTensor(train_weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            train_weights, len(train_weights))
    if os.path.exists(args.val_dataset_weights):
        val_weights = json.load(open(args.val_dataset_weights))
        val_weights = torch.DoubleTensor(val_weights)
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            val_weights, len(val_weights))
    data_loader = get_loader(args.dataset, transform, args.batch_size,
                             shuffle=False, num_workers=args.num_workers,
                             max_examples=args.max_examples,
                             sampler=train_sampler)
    val_data_loader = get_loader(args.val_dataset, transform, args.batch_size,
                                 shuffle=False, num_workers=args.num_workers,
                                 max_examples=args.max_examples,
                                 sampler=val_sampler)
    print('Done loading data ............................')
    logging.info("Done")

    vqg = create_model(args, vocab)
    if args.load_model is not None:
        vqg.load_state_dict(torch.load(args.load_model))
    logging.info("Done")

    # Loss criteria. Note that, unlike the VQA script above, the pad index is
    # not passed as ignore_index here.
    pad = vocab(vocab.SYM_PAD)
    criterion = nn.CrossEntropyLoss()
    criterion2 = nn.MultiMarginLoss().cuda()  # Defined but not used below.
    l2_criterion = nn.MSELoss()

    alpha = None
    if args.bayes:
        alpha = vqg.alpha

    # Setup GPUs.
    if torch.cuda.is_available():
        logging.info("Using available GPU...")
        vqg.cuda()
        criterion.cuda()
        l2_criterion.cuda()
        if alpha is not None:
            alpha = alpha.cuda()

    # Parameters to train.
    gen_params = vqg.generator_parameters()
    info_params = vqg.info_parameters()
    learning_rate = args.learning_rate
    info_learning_rate = args.info_learning_rate
    gen_optimizer = torch.optim.Adam(gen_params, lr=learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)
    if args.step_two:
        cycle_params = vqg.cycle_params()
        cycle_optimizer = torch.optim.Adam(cycle_params, lr=learning_rate)
    if args.center_loss:
        center_loss = CenterLoss(num_classes=args.num_categories,
                                 feat_dim=args.z_size, use_gpu=True)
        optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)
    scheduler = ReduceLROnPlateau(optimizer=gen_optimizer, mode='min',
                                  factor=0.5, patience=args.patience,
                                  verbose=True, min_lr=1e-7)
    # Note: attached to gen_optimizer and never stepped below.
    cycle_scheduler = ReduceLROnPlateau(optimizer=gen_optimizer, mode='min',
                                        factor=0.99, patience=args.patience,
                                        verbose=True, min_lr=1e-7)
    info_scheduler = ReduceLROnPlateau(optimizer=info_optimizer, mode='min',
                                       factor=0.5, patience=args.patience,
                                       verbose=True, min_lr=1e-7)

    # Train the model.
    total_steps = len(data_loader)
    start_time = time.time()
    n_steps = 0

    # Optional losses. Initialized here for logging.
    recon_category_loss = 0.0
    recon_image_loss = 0.0
    kl_loss = 0.0
    category_cycle_loss = 0.0
    regularisation_loss = 0.0
    c_loss = 0.0
    cycle_loss = 0.0
    if args.step_two:
        category_cycle_loss = 0.0
    if args.bayes:
        regularisation_loss = 0.0
    if args.center_loss:
        loss_center = 0.0
        c_loss = 0.0

    for epoch in range(args.num_epochs):
        for i, (images, questions, answers,
                categories, qindices) in enumerate(data_loader):
            n_steps += 1
            # TODO: remove answers from the data loader later.

            # Set mini-batch dataset.
            if torch.cuda.is_available():
                images = images.cuda()
                questions = questions.cuda()
                answers = answers.cuda()
                categories = categories.cuda()
                qindices = qindices.cuda()
                if args.bayes:
                    alpha = alpha.cuda()

            # Eval now.
            if (args.eval_every_n_steps is not None and
                    n_steps >= args.eval_every_n_steps and
                    n_steps % args.eval_every_n_steps == 0):
                run_eval(vqg, val_data_loader, criterion, l2_criterion,
                         args, epoch, scheduler, info_scheduler)
                compare_outputs(images, questions, answers, categories,
                                vqg, vocab, logging, cat2name, args)

            # Forward.
            vqg.train()
            gen_optimizer.zero_grad()
            info_optimizer.zero_grad()
            if args.step_two:
                cycle_optimizer.zero_grad()
            if args.center_loss:
                optimizer_centloss.zero_grad()
            image_features = vqg.encode_images(images)
            category_features = vqg.encode_categories(categories)

            # Question generation.
            t_mus, t_logvars, ts = vqg.encode_into_t(image_features,
                                                     category_features)

            # Center loss on the latent space.
            if args.center_loss:
                loss_center = 0.0
                c_loss = center_loss(ts, categories)
                loss_center += args.lambda_centerloss * c_loss
                c_loss = c_loss.item()
                loss_center.backward(retain_graph=True)
                # Undo the lambda scaling on the center gradients.
                for param in center_loss.parameters():
                    param.grad.data *= (1. / args.lambda_centerloss)
                optimizer_centloss.step()

            qlengths_prev = process_lengths(questions)
            (outputs, _, _), pred_ques = vqg.decode_questions(
                image_features, ts, questions=questions,
                teacher_forcing_ratio=1.0)

            # Reorder the questions based on length.
            questions = torch.index_select(questions, 0, qindices)

            # Ignore the start token.
            questions = questions[:, 1:]
            qlengths = process_lengths(questions)

            # Convert the output from a MAX_LEN list of (BATCH x VOCAB) ->
            # (BATCH x MAX_LEN x VOCAB).
            outputs = [o.unsqueeze(1) for o in outputs]
            outputs = torch.cat(outputs, dim=1)
            outputs = torch.index_select(outputs, 0, qindices)

            # Category-cycle loss (step two): re-encode the predicted questions
            # and check that the category can be recovered.
            if args.step_two:
                category_cycle_loss = 0.0
                category_cycle = vqg.encode_questions(pred_ques, qlengths)
                cycle_loss = criterion(category_cycle, categories)
                category_cycle_loss += args.lambda_c_cycle * cycle_loss
                cycle_loss = cycle_loss.item()
                category_cycle_loss.backward(retain_graph=True)
                cycle_optimizer.step()

            # Calculate the generation loss.
            targets = pack_padded_sequence(questions, qlengths,
                                           batch_first=True)[0]
            outputs = pack_padded_sequence(outputs, qlengths,
                                           batch_first=True)[0]
            gen_loss = criterion(outputs, targets)
            total_loss = 0.0
            total_loss += args.lambda_gen * gen_loss
            gen_loss = gen_loss.item()

            # Variational loss.
            if args.bayes:
                # KL(N(t_mu, exp(t_logvar)) || N(0, alpha^-2)) plus an l2 penalty
                # that keeps the prior scale 1/alpha close to 1.
                kl_loss = -0.5 * torch.sum(
                    1 + t_logvars + alpha.pow(2).log()
                    - alpha.pow(2) * (t_mus.pow(2) + t_logvars.exp()))
                regularisation_loss = l2_criterion(alpha.pow(-1),
                                                   torch.ones_like(alpha))
                total_loss += (args.lambda_t * kl_loss +
                               args.lambda_reg * regularisation_loss)
                kl_loss = kl_loss.item()
                regularisation_loss = regularisation_loss.item()
            else:
                kl_loss = gaussian_KL_loss(t_mus, t_logvars)
                total_loss += args.lambda_t * kl_loss
                kl_loss = kl_loss.item()

            # Generator backprop.
            total_loss.backward(retain_graph=True)
            gen_optimizer.step()

            # Reconstruction losses.
            recon_image_loss = 0.0
            recon_category_loss = 0.0
            if not args.no_category_space or not args.no_image_recon:
                total_info_loss = 0.0
                category_targets = category_features.detach()
                image_targets = image_features.detach()
                recon_image_features, recon_category_features = vqg.reconstruct_inputs(
                    image_targets, category_targets)

                # Category reconstruction loss.
                if not args.no_category_space:
                    recon_c_loss = l2_criterion(recon_category_features,
                                                category_targets)  # changed to criterion2
                    total_info_loss += args.lambda_c * recon_c_loss
                    recon_category_loss = recon_c_loss.item()

                # Image reconstruction loss.
                if not args.no_image_recon:
                    recon_i_loss = l2_criterion(recon_image_features,
                                                image_targets)
                    total_info_loss += args.lambda_i * recon_i_loss
                    recon_image_loss = recon_i_loss.item()

                # Info backprop.
                total_info_loss.backward()
                info_optimizer.step()

            # Print log info.
            if i % args.log_step == 0:
                delta_time = time.time() - start_time
                start_time = time.time()
                logging.info('Time: %.4f, Epoch [%d/%d], Step [%d/%d], '
                             'LR: %f, Center-Loss: %.4f, KL: %.4f, '
                             'I-recon: %.4f, C-recon: %.4f, '
                             'C-cycle: %.4f, Regularisation: %.4f'
                             % (delta_time, epoch, args.num_epochs, i,
                                total_steps,
                                gen_optimizer.param_groups[0]['lr'],
                                c_loss, kl_loss,
                                recon_image_loss, recon_category_loss,
                                cycle_loss, regularisation_loss))

            # Save the step checkpoints.
            if args.save_step is not None and (i + 1) % args.save_step == 0:
                torch.save(vqg.state_dict(),
                           os.path.join(args.model_path,
                                        'vqg-tf-%d-%d.pkl' % (epoch + 1, i + 1)))
                if args.center_loss:
                    torch.save(center_loss.state_dict(),
                               os.path.join(args.model_path,
                                            'closs-tf-%d-%d.pkl' % (epoch + 1, i + 1)))

        # Save the epoch checkpoint.
        torch.save(vqg.state_dict(),
                   os.path.join(args.model_path, 'vqg-tf-%d.pkl' % (epoch + 1)))

        # Evaluation and learning rate updates.
        run_eval(vqg, val_data_loader, criterion, l2_criterion, args,
                 epoch, scheduler, info_scheduler)
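# `CenterLoss` is imported from elsewhere in the repository; it is constructed
# above as CenterLoss(num_classes, feat_dim, use_gpu) and called as
# center_loss(ts, categories). The sketch below is an assumption about a
# compatible implementation (the widely used center-loss formulation with one
# learnable center per class), not the authors' exact code.
class CenterLoss(nn.Module):
    """Hedged sketch: mean squared distance of each latent to its class center."""

    def __init__(self, num_classes, feat_dim, use_gpu=True):
        super(CenterLoss, self).__init__()
        centers = torch.randn(num_classes, feat_dim)
        if use_gpu:
            centers = centers.cuda()
        # Learnable per-class centers, updated by optimizer_centloss above.
        self.centers = nn.Parameter(centers)

    def forward(self, features, labels):
        # features: (BATCH x feat_dim) latents, labels: (BATCH,) category ids.
        batch_centers = self.centers.index_select(0, labels)
        return (features - batch_centers).pow(2).sum(dim=1).mean()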