def main(args):
    """Train the story encoder/decoder (visual storytelling).

    Builds train/val data loaders from ``args``, optionally resumes from a
    saved checkpoint, trains with Adam, checkpoints both models after every
    epoch, and stops early after 10 consecutive epochs in which the average
    validation loss fails to improve on the best seen so far.

    Fixes vs. original:
      * validation now runs under ``torch.no_grad()`` (the original built
        autograd graphs for every val batch, wasting memory),
      * checkpoint resume loads from ``args.model_path`` instead of a
        hard-coded ``'./models'`` (checkpoints are saved to
        ``args.model_path``, so resume previously broke when the two
        differed),
      * local name typo ``photo_squence_set`` corrected.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing: random crop + horizontal flip augmentation for
    # training, deterministic resize for validation. The normalization
    # constants are the standard ImageNet channel mean/std.
    train_transform = transforms.Compose([
        transforms.RandomCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    val_transform = transforms.Compose([
        transforms.Resize(args.image_size, interpolation=Image.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loaders.
    train_data_loader = get_loader(args.train_image_dir, args.train_sis_path,
                                   vocab, train_transform, args.batch_size,
                                   shuffle=True, num_workers=args.num_workers)
    val_data_loader = get_loader(args.val_image_dir, args.val_sis_path,
                                 vocab, val_transform, args.batch_size,
                                 shuffle=False, num_workers=args.num_workers)

    encoder = EncoderStory(args.img_feature_size, args.hidden_size,
                           args.num_layers)
    decoder = DecoderStory(args.embed_size, args.hidden_size, vocab)

    # Optionally resume from a checkpoint. FIX: load from args.model_path —
    # the same directory checkpoints are saved to — not a hard-coded path.
    pretrained_epoch = 0
    if args.pretrained_epoch > 0:
        pretrained_epoch = args.pretrained_epoch
        encoder.load_state_dict(torch.load(os.path.join(
            args.model_path, 'encoder-%d.pkl' % pretrained_epoch)))
        decoder.load_state_dict(torch.load(os.path.join(
            args.model_path, 'decoder-%d.pkl' % pretrained_epoch)))

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()
        print("Cuda is enabled...")

    criterion = nn.CrossEntropyLoss()
    params = decoder.get_params() + encoder.get_params()
    optimizer = torch.optim.Adam(params, lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    total_train_step = len(train_data_loader)
    total_val_step = len(val_data_loader)

    min_avg_loss = float("inf")
    overfit_warn = 0

    for epoch in range(args.num_epochs):
        # Skip epochs already covered by the loaded checkpoint.
        if epoch < pretrained_epoch:
            continue

        # ---- Training ----
        encoder.train()
        decoder.train()
        avg_loss = 0.0
        for bi, (image_stories, targets_set, lengths_set,
                 photo_sequence_set, album_ids_set) in enumerate(train_data_loader):
            decoder.zero_grad()
            encoder.zero_grad()
            loss = 0
            images = to_var(torch.stack(image_stories))
            features, _ = encoder(images)

            # One story = 5 captions. Sum per-caption cross-entropy,
            # truncating each target to its true (unpadded) length.
            for si, data in enumerate(zip(features, targets_set, lengths_set)):
                feature = data[0]
                captions = to_var(data[1])
                lengths = data[2]
                outputs = decoder(feature, captions, lengths)
                for sj, result in enumerate(zip(outputs, captions, lengths)):
                    loss += criterion(result[0], result[1][0:result[2]])

            avg_loss += loss.item()
            loss /= (args.batch_size * 5)  # 5 sentences per story
            loss.backward()
            optimizer.step()

            # Print log info
            if bi % args.log_step == 0:
                print('Epoch [%d/%d], Train Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                      % (epoch + 1, args.num_epochs, bi, total_train_step,
                         loss.item(), np.exp(loss.item())))

        avg_loss /= (args.batch_size * total_train_step * 5)
        print('Epoch [%d/%d], Average Train Loss: %.4f, Average Train Perplexity: %5.4f'
              % (epoch + 1, args.num_epochs, avg_loss, np.exp(avg_loss)))

        # Save the models
        torch.save(decoder.state_dict(),
                   os.path.join(args.model_path, 'decoder-%d.pkl' % (epoch + 1)))
        torch.save(encoder.state_dict(),
                   os.path.join(args.model_path, 'encoder-%d.pkl' % (epoch + 1)))

        # ---- Validation ----
        # FIX: run under torch.no_grad() so no autograd state is recorded.
        encoder.eval()
        decoder.eval()
        avg_loss = 0.0
        with torch.no_grad():
            for bi, (image_stories, targets_set, lengths_set,
                     photo_sequence_set, album_ids_set) in enumerate(val_data_loader):
                loss = 0
                images = to_var(torch.stack(image_stories))
                features, _ = encoder(images)

                for si, data in enumerate(zip(features, targets_set, lengths_set)):
                    feature = data[0]
                    captions = to_var(data[1])
                    lengths = data[2]
                    outputs = decoder(feature, captions, lengths)
                    for sj, result in enumerate(zip(outputs, captions, lengths)):
                        loss += criterion(result[0], result[1][0:result[2]])

                avg_loss += loss.item()
                loss /= (args.batch_size * 5)

                # Print log info
                if bi % args.log_step == 0:
                    print('Epoch [%d/%d], Val Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                          % (epoch + 1, args.num_epochs, bi, total_val_step,
                             loss.item(), np.exp(loss.item())))

        avg_loss /= (args.batch_size * total_val_step * 5)
        print('Epoch [%d/%d], Average Val Loss: %.4f, Average Val Perplexity: %5.4f'
              % (epoch + 1, args.num_epochs, avg_loss, np.exp(avg_loss)))

        # Termination condition: count consecutive epochs without a new
        # best validation loss; stop after 10.
        overfit_warn = overfit_warn + 1 if (min_avg_loss < avg_loss) else 0
        min_avg_loss = min(min_avg_loss, avg_loss)

        if overfit_warn >= 10:
            break
]) with open(args.vocab_path, 'rb') as f: vocab = pickle.load(f) data_loader = get_loader(image_dir, sis_path, vocab, transform, args.batch_size, shuffle=False, num_workers=args.num_workers) encoder = EncoderStory(args.img_feature_size, args.hidden_size, args.num_layers) decoder = DecoderStory(args.embed_size, args.hidden_size, vocab) encoder.load_state_dict(torch.load(encoder_path)) decoder.load_state_dict(torch.load(decoder_path)) encoder.eval() decoder.eval() if torch.cuda.is_available(): encoder.cuda() decoder.cuda() print("Cuda is enabled...") criterion = nn.CrossEntropyLoss() results = []