def train(conf):
    """Train a VRNN and periodically sample from it and checkpoint it.

    Attributes read from ``conf``:
        x_dim, h_dim, z_dim: sizes forwarded to the ``VRNN`` constructor;
            ``x_dim`` is also used as the side length of the sampled image.
        restore: when equal to ``True``, weights are loaded from
            ``checkpoint_path`` before training starts.
        checkpoint_path: file passed to ``torch.load`` when restoring.
        train_epoch: number of epochs to run (epochs are numbered from 1).
        save_every: a checkpoint is written whenever ``ep % save_every == 0``.

    Side effects: prints progress, draws one sampled image per epoch with
    matplotlib, and writes checkpoints under ``../checkpoint/``.
    """
    # Function-scope import so this block does not depend on a file-level one.
    import os

    checkpoint_dir = '../checkpoint'

    # Only the training split is consumed here.
    train_loader, _ = load_dataset(512)

    net = VRNN(conf.x_dim, conf.h_dim, conf.z_dim)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.cuda.manual_seed_all(112858)
    net.to(device)

    # Use at most GPUs 0 and 1 (the original hard-coded pair), but never ask
    # for a device that does not exist: a fixed [0, 1] fails on a single-GPU
    # host. The DataParallel wrapper is kept in every case so that
    # ``net.module`` and the "module."-prefixed state-dict keys stay valid.
    device_ids = list(range(min(2, torch.cuda.device_count())))
    net = torch.nn.DataParallel(net, device_ids=device_ids or None)

    # NOTE(review): ``== True`` is kept on purpose; a plain truthiness test
    # would also accept non-empty strings such as "False".
    if conf.restore == True:
        # Map onto the device actually selected above rather than a fixed
        # 'cuda:0', so restoring also works on a CPU-only machine.
        net.load_state_dict(
            torch.load(conf.checkpoint_path, map_location=device))
        print('Restore model from ' + conf.checkpoint_path)

    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for ep in range(1, conf.train_epoch + 1):
        # NOTE(review): 117 presumably approximates the number of batches per
        # epoch for batch size 512 -- confirm against load_dataset.
        prog = Progbar(target=117)
        print("At epoch:{}".format(str(ep)))

        for i, (data, _) in enumerate(train_loader):
            data = data.squeeze(1)
            data = (data / 255).to(device)

            package = net(data)
            loss = Loss(package, data)

            net.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
            optimizer.step()

            prog.update(i, exact=[("Training Loss", loss.item())])

        # Draw one sample from the current model after each epoch.
        with torch.no_grad():
            x_decoded = net.module.sampling(conf.x_dim, device)
            x_decoded = x_decoded.cpu().numpy()
            digit = x_decoded.reshape(conf.x_dim, conf.x_dim)
            plt.imshow(digit, cmap='Greys_r')
            plt.pause(1e-6)

        if ep % conf.save_every == 0:
            # Create the target directory on demand so that a missing folder
            # does not throw away a finished epoch.
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            # NOTE(review): the file is labelled ``ep + 1`` although it holds
            # the weights after epoch ``ep``; kept unchanged so existing
            # checkpoint names remain valid -- confirm whether intended.
            torch.save(net.state_dict(),
                       checkpoint_dir + '/Epoch_' + str(ep + 1) + '.pth')
def _parse_start_epoch(checkpoint_path):
    """Return the epoch encoded in a ``<name>-<epoch>-<step>`` file name, else 0."""
    # Only the file name is inspected, so a hyphen in a directory name cannot
    # shift the fields.
    try:
        return int(float(os.path.basename(checkpoint_path).split('-')[1]))
    except (IndexError, ValueError):
        return 0


def _append_log(log_path, text):
    """Append ``text`` to the log file at ``log_path``."""
    with open(log_path, 'a') as f:
        f.write(text)


def main(args):
    """Train the image encoder / VRNN caption decoder pair.

    Attributes read from ``args``:
        model_path: output directory for checkpoints, loss pickles and the log.
        logfile: log file name, created inside ``model_path``.
        vocab_path: pickled vocabulary wrapper.
        image_dir, caption_path, batch_size, num_workers: forwarded to
            ``get_loader``.
        crop_size: size of the random crop applied to each image.
        encoder, decoder: checkpoint paths to resume from; an empty value or
            ``'new'`` starts from fresh weights.
        restart: when truthy, ignore both checkpoint paths.
        train_encoder: when falsy the encoder is put in eval mode and its
            checkpoints are not saved.
        embed_size, hidden_size, latent_size, num_layers: model sizes.
        learning_rate, num_epochs, z_step, log_step, save_step: training
            hyper-parameters and reporting intervals.

    Side effects: creates ``model_path``, prints progress, appends to the log
    file, and writes ``decoder-*.pkl`` / ``encoder-*.pkl`` /
    ``training_loss.pkl`` / ``training_val.pkl`` into ``model_path``.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Every output goes through os.path.join, like the checkpoints below, so a
    # model_path without a trailing separator still writes inside the folder.
    log_path = os.path.join(args.model_path, args.logfile)
    loss_path = os.path.join(args.model_path, 'training_loss.pkl')
    val_path = os.path.join(args.model_path, 'training_val.pkl')

    # Load vocabulary wrapper.
    # NOTE: pickle executes arbitrary code on load; only use trusted files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Image preprocessing
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Build the models
    encoder = EncoderCNN(args.embed_size)
    if not args.train_encoder:
        encoder.eval()
    decoder = VRNN(args.embed_size, args.hidden_size, len(vocab),
                   args.latent_size, args.num_layers)

    # Resolve which checkpoints to use; empty or missing values mean "new".
    encoder_state = args.encoder
    decoder_state = args.decoder
    if args.restart:
        encoder_state, decoder_state = 'new', 'new'
    if not encoder_state:
        encoder_state = 'new'
    if not decoder_state:
        decoder_state = 'new'

    print("Using encoder: {}".format(encoder_state))
    print("Using decoder: {}".format(decoder_state))

    # Resume the epoch counter from the decoder checkpoint name, if it has one.
    start_epoch = _parse_start_epoch(decoder_state)

    if encoder_state != 'new':
        encoder.load_state_dict(torch.load(encoder_state))
    if decoder_state != 'new':
        decoder.load_state_dict(torch.load(decoder_state))

    # Build data loader
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    # Record in the log which weights this run actually started from.
    _append_log(log_path, "Using encoder: {}\nUsing decoder: {}\n\n".format(
        encoder_state, decoder_state))

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    # Loss and optimizer; only the encoder's linear and bn layers are
    # optimized alongside the decoder.
    cross_entropy = nn.CrossEntropyLoss()
    params = (list(decoder.parameters())
              + list(encoder.linear.parameters())
              + list(encoder.bn.parameters()))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    batch_loss = []
    batch_loss_det = []
    batch_kl = []
    batch_ml = []
    # NOTE(review): never filled in this function, so training_val.pkl always
    # holds an empty list -- confirm whether accuracy tracking is still wanted.
    batch_acc = []

    # Train the Models
    total_step = len(data_loader)
    for epoch in range(start_epoch, args.num_epochs):
        for i, (images, captions, lengths, _, _) in enumerate(data_loader):
            # get lengths excluding <start> symbol
            lengths = [length - 1 for length in lengths]

            # Set mini-batch dataset
            images = to_var(images, volatile=True)
            captions = to_var(captions)

            # The slicing below needs every caption to extend past z_step.
            assert min(lengths) > args.z_step + 2

            # Token at the latent step, and the packed remainder that the
            # deterministic decoder has to predict.
            targets_var = captions[:, args.z_step + 1]
            targets_det = pack_padded_sequence(
                captions[:, args.z_step + 2:],
                [length - args.z_step - 1 for length in lengths],
                batch_first=True)[0]

            # Get prior and approximate distributions
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(images)
            prior, q_z, q_x, det_x = decoder(features, captions, lengths,
                                             z_step=args.z_step)

            # Calculate KL Divergence
            kl = torch.mean(kl_divergence(*(q_z + prior)))

            # Get marginal likelihood from log likelihood of the correct
            # symbol. The row index is built on q_x's own device, so this also
            # works when CUDA is unavailable.
            rows = torch.arange(q_x.shape[0], dtype=torch.long,
                                device=q_x.device)
            ml = torch.mean(q_x[rows, targets_var])

            # Get Cross-Entropy loss for deterministic decoder
            ce = cross_entropy(det_x, targets_det)

            elbo = ml - kl
            loss_var = -elbo
            loss_det = ce
            loss = loss_var + loss_det

            # .item() is the supported way to read a scalar tensor; indexing
            # a 0-dim tensor with [0] raises on current PyTorch.
            loss_value = loss.item()
            batch_loss.append(loss_value)
            batch_loss_det.append(loss_det.item())
            batch_kl.append(kl.item())
            batch_ml.append(ml.item())

            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                summary = (
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                    % (epoch, args.num_epochs, i, total_step,
                       loss_value, np.exp(loss_value)))
                print(summary)
                # One entry per line; without the newline the entries run
                # together in the log file.
                _append_log(log_path, summary + '\n')

            # Save the models
            if (i + 1) % args.save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                if args.train_encoder:
                    torch.save(
                        encoder.state_dict(),
                        os.path.join(args.model_path,
                                     'encoder-%d-%d.pkl' % (epoch + 1, i + 1)))

        # Persist the running curves after every epoch. Binary mode is
        # required: pickle writes bytes.
        with open(loss_path, 'wb') as f:
            pickle.dump(batch_loss, f)
        with open(val_path, 'wb') as f:
            pickle.dump(batch_acc, f)

    _append_log(log_path,
                "Training finished at {} .\n\n".format(str(datetime.now())))