import time

import torch
from torch.nn.utils.rnn import pack_padded_sequence
from nltk.translate.bleu_score import corpus_bleu

# AverageMeter, accuracy, and print_predictions are project-level helpers
# (a sketch of the first two, as assumed here, follows the function below).


def validate(val_loader, decoder, criterion_ce, i2w, device, print_freq,
             word_map, current_epoch, break_flag, top_x, smoothing_method,
             print_flag):
    """
    Performs one epoch's validation.

    :param val_loader: DataLoader for validation data
    :param decoder: decoder model
    :param criterion_ce: cross-entropy loss layer
    :param i2w: index-to-word mapping, used when printing predictions
    :param device: device to run on
    :param print_freq: print status every print_freq batches
    :param word_map: word-to-index mapping
    :param current_epoch: current epoch number
    :param break_flag: if True, validate on only 5 batches (for debugging)
    :param top_x: k for top-k accuracy
    :param smoothing_method: smoothing function passed to corpus_bleu
    :param print_flag: if True, print per-batch predictions
    :return: BLEU-4 score and the loss meter
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()  # tracks top-k accuracy with k = top_x

    start = time.time()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # Batches
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            if break_flag and i == 5:
                break  # only 5 batches

            print('val i', i)

            imgs, caps, caplens, allcaps = data

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            scores, caps_sorted, decode_lengths, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            if print_flag:
                print_predictions(scores, targets, i2w)

            # Remove timesteps that we didn't decode at, or that are pads;
            # pack_padded_sequence is an easy trick to do this. (Recent PyTorch
            # returns a PackedSequence, so we take its .data field rather than
            # tuple-unpacking it.)
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion_ce(scores, targets)

            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, top_x)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-{topx} Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, topx=top_x, top5=top5accs))

            # Store references (true captions) and hypotheses (predictions) for each image.
            # If for n images we have n hypotheses, and references a, b, c... for each image, we need:
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References
            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
            # DIDEC captions of the other participants come here
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
                        img_caps))  # remove <start> and pads
                refs_per_img = [ic for ic in img_captions if len(ic) > 0]
                references.append(refs_per_img)

            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                # remove pads; superfluous <end> tokens stay because of teacher forcing
                temp_preds.append(preds[j][:decode_lengths[j]])
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

    # Calculate BLEU-4 score
    bleu4 = corpus_bleu(references, hypotheses, smoothing_function=smoothing_method)
    bleu4 = round(bleu4, 4)

    print('\n * LOSS - {loss.avg:.3f}, TOP-{topx} ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
        loss=losses, topx=top_x, top5=top5accs, bleu=bleu4))

    return bleu4, losses
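# --- Assumed helpers (not defined in this file) ---
# A minimal sketch of what `AverageMeter` and `accuracy` are assumed to look
# like, following the common PyTorch ImageNet-example pattern; the project's
# actual implementations may differ.

class AverageMeter(object):
    """Keeps track of the most recent value, sum, count, and running average."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.avg = 0.0   # running average
        self.sum = 0.0   # weighted sum of values
        self.count = 0   # total weight (e.g. number of decoded words)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(scores, targets, k):
    """Computes top-k accuracy (in percent) from packed scores and targets."""
    batch_size = targets.size(0)
    _, ind = scores.topk(k, dim=1, largest=True, sorted=True)  # (N, k) predicted indices
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))       # (N, k) hit mask
    correct_total = correct.view(-1).float().sum()
    return correct_total.item() * (100.0 / batch_size)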
print('Epoch', epoch)
print('Train')

# Terminate training after 20 epochs without improvement
if epochs_since_improvement == 20:
    break

# Halve the learning rate every 10 epochs without improvement (currently disabled):
# if epochs_since_improvement > 0 and epochs_since_improvement % 10 == 0:
#     adjust_learning_rate(decoder_optimizer, 0.5)

decoder.train()
torch.set_grad_enabled(True)  # a bare torch.enable_grad() call is a no-op outside a `with` block

batch_time = AverageMeter()  # forward prop. + back prop. time
data_time = AverageMeter()   # data loading time
losses = AverageMeter()      # loss (per word decoded)
top5accs = AverageMeter()    # top-k accuracy (k = top_x)

start = time.time()

count = 0

for i, data in enumerate(training_loader):
    if break_flag and count == 1:
        break  # only one batch (for debugging)
    count += 1
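# `adjust_learning_rate` (referenced in the commented-out decay step above) is
# not defined in this fragment. A minimal sketch, assuming the conventional
# shrink-by-factor behavior; the project's actual helper may differ.

def adjust_learning_rate(optimizer, shrink_factor):
    """Multiplies the learning rate of every parameter group by shrink_factor."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print("New learning rate: {:.6f}".format(optimizer.param_groups[0]['lr']))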