def train():
    """Train the Autoencoder on MNIST with MSE reconstruction loss.

    Relies on module-level config/state: ``opt`` (hyperparameters), ``cuda``
    (bool), ``IMAGE_PATH``, ``MODEL_FULLPATH`` and helper ``show_progress``.
    After every epoch it saves a grid of reconstructions of a fixed sample
    batch; after training it saves the model state dict and, if
    ``opt.history`` is set, the per-batch loss history.
    """
    ae = Autoencoder()
    # load trained model
    # model_path = ''
    # g.load_state_dict(torch.load(model_path))
    criterion = torch.nn.MSELoss()
    optimizer = optim.Adam(ae.parameters(), lr=opt.lr, weight_decay=opt.decay)

    # load dataset
    # ==========================
    kwargs = dict(num_workers=1, pin_memory=True) if cuda else {}
    dataloader = DataLoader(
        datasets.MNIST('MNIST', download=True,
                       transform=transforms.Compose([transforms.ToTensor()])),
        batch_size=opt.batch_size,
        shuffle=True,
        **kwargs)
    N = len(dataloader)

    # Get a fixed sample batch used to visualize reconstructions each epoch.
    # FIX: `dataiter.next()` does not exist on Python 3 iterators / modern
    # torch dataloaders — use the builtin next().
    dataiter = iter(dataloader)
    samples, _ = next(dataiter)

    # cuda
    if cuda:
        ae.cuda()
        criterion.cuda()
        samples = samples.cuda()
    samples = Variable(samples)

    if opt.history:
        loss_history = np.empty(N * opt.epochs, dtype=np.float32)

    # train
    # ==========================
    for epoch in range(opt.epochs):
        loss_mean = 0.0
        for i, (imgs, _) in enumerate(dataloader):
            if cuda:
                imgs = imgs.cuda()
            imgs = Variable(imgs)

            # forward & backward & update params
            ae.zero_grad()
            _, outputs = ae(imgs)
            loss = criterion(outputs, imgs)
            loss.backward()
            optimizer.step()

            # FIX: `loss.data[0]` raises IndexError on 0-dim loss tensors
            # since PyTorch 0.4 — read the scalar once with .item().
            batch_loss = loss.item()
            loss_mean += batch_loss
            if opt.history:
                loss_history[N * epoch + i] = batch_loss
            show_progress(epoch + 1, i + 1, N, batch_loss)

        print('\ttotal loss (mean): %f' % (loss_mean / N))

        # generate fake images
        _, reconst = ae(samples)
        vutils.save_image(reconst.data,
                          os.path.join(IMAGE_PATH, '%d.png' % (epoch + 1)),
                          normalize=False)

    # save models
    torch.save(ae.state_dict(), MODEL_FULLPATH)
    # save loss history
    if opt.history:
        np.save('history/' + opt.name, loss_history)
def main(args):
    """Train an image-captioning EncoderCNN + Autoencoder pair.

    Only parameters with ``requires_grad`` (excluding the autoencoder's
    first parameter — presumably frozen pretrained embeddings; confirm
    against the Autoencoder definition) and the encoder's final linear
    layer are optimized. Checkpoints are written to ``args.model_path``:
    a "-best" pair whenever the batch loss improves, and a numbered pair
    every ``args.save_step`` steps. Losses are logged to TensorBoard.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    # Build the models
    encoder = EncoderCNN(args.embed_size).to(device)
    autoencoder = Autoencoder(args.embed_size, args.embeddings_path,
                              args.hidden_size, len(vocab),
                              args.num_layers).to(device)
    print(len(vocab))

    # Optimizer: skip the autoencoder's first parameter and keep only
    # trainable ones, plus the encoder's final linear layer.
    params = list(
        filter(
            lambda p: p.requires_grad,
            list(autoencoder.parameters())[1:] +
            list(encoder.linear.parameters())))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Define summary writer
    writer = SummaryWriter()

    # Loss tracker
    best_loss = float('inf')

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # Forward, backward and optimize
            features = encoder(images)
            L_ling, L_vis = autoencoder(features, captions, lengths)
            # Want visual loss to have bigger impact
            loss = 0.2 * L_ling + 0.8 * L_vis
            autoencoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Save the model checkpoints when loss improves.
            # FIX: store the Python scalar, not the loss tensor —
            # keeping the tensor would hold the whole autograd graph
            # alive across iterations.
            if loss.item() < best_loss:
                best_loss = loss.item()
                print("Saving checkpoints")
                # FIX: dropped no-op `.format(epoch + 1, i + 1)` calls —
                # the "-best" filenames contain no placeholders.
                torch.save(
                    autoencoder.state_dict(),
                    os.path.join(args.model_path,
                                 'autoencoder-frozen-best.ckpt'))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-frozen-best.ckpt'))

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step,
                            loss.item(), np.exp(loss.item())))

                # Log train loss on tensorboard
                writer.add_scalar('frozen-loss/L_ling', L_ling.item(),
                                  epoch * total_step + i)
                writer.add_scalar('frozen-loss/L_vis', L_vis.item(),
                                  epoch * total_step + i)
                writer.add_scalar('frozen-loss/combined', loss.item(),
                                  epoch * total_step + i)

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    autoencoder.state_dict(),
                    os.path.join(
                        args.model_path,
                        'autoencoder-frozen-{}-{}.ckpt'.format(
                            epoch + 1, i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(
                        args.model_path,
                        'encoder-frozen-{}-{}.ckpt'.format(epoch + 1, i + 1)))