def main():
    """Train the coupled image/text VAE on paired images and captions.

    All settings come from the module-level ``parser``.  Each epoch:
    optionally resumes from a checkpoint, trains the joint VAE (alternating
    the "swapped" pairing every epoch), periodically dumps reconstructed
    images and captions under ``result_path``, appends average losses to
    ``losses.csv`` and saves the model under ``model_path``.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "test":
        print("WARNING: name is test!!!\n\n")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge", "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'
    assert args.common_emb_ratio <= 1.0 and args.common_emb_ratio >= 0

    # Width of the shared (cross-modal) slice of the hidden representation.
    mask = int(args.common_emb_ratio * args.hidden_size)

    # args.cuda arrives as the string 'true'/'false'.
    cuda = args.cuda == 'true'

    if args.load_model == "NONE":
        keep_loading = False
        model_path = args.model_path + args.comment + "/"
    else:
        keep_loading = True
        model_path = args.model_path + args.load_model + "/"

    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # When only a directory is given, point at the GloVe file for this size.
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)
    glove_emb = nn.Embedding(emb.size(0), emb.size(1))

    # Freeze the embedding weights when requested.
    if args.fixed_embeddings == "true":
        glove_emb.weight.requires_grad = False
    # </editor-fold>

    # <editor-fold desc="Data-Loaders">
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                            transform, args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
    # </editor-fold>

    # <editor-fold desc="Network Initialization">
    print("Setting up the Networks...")
    coupled_vae = CoupledVAE(glove_emb, len(vocab),
                             hidden_size=args.hidden_size,
                             latent_size=args.latent_size,
                             batch_size=args.batch_size)
    if cuda:
        coupled_vae = coupled_vae.cuda()
    # </editor-fold>

    # <editor-fold desc="Optimizers">
    print("Setting up the Optimizers...")
    vae_optim = optim.Adam(coupled_vae.parameters(), lr=args.learning_rate,
                           betas=(0.5, 0.999), weight_decay=0.00001)
    # </editor-fold>

    train_swapped = False  # Reverse 2
    step = 0

    # Start the CSV log with a header; per-epoch rows are appended below.
    with open(os.path.join(result_path, "losses.csv"), "w") as text_file:
        text_file.write("Epoch, Img, Txt, CM\n")

    for epoch in range(args.num_epochs):
        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()
        bar = Bar('Training Net', max=len(data_loader))

        # Fast-forward through epochs for which a checkpoint already exists;
        # the first missing checkpoint switches us back to training.
        if keep_loading:
            suffix = "-" + str(epoch) + "-" + args.load_model + ".pkl"
            try:
                coupled_vae.load_state_dict(
                    torch.load(os.path.join(args.model_path, 'coupled_vae' + suffix)))
            except FileNotFoundError:
                print("Didn't find any models switching to training")
                keep_loading = False

        if not keep_loading:
            # Set training mode
            coupled_vae.train()

        # Alternate the swapped/straight pairing every epoch.
        train_swapped = not train_swapped

        for i, (images, captions, lengths) in enumerate(data_loader):
            # Skip the last (possibly short) batch.
            if i == len(data_loader) - 1:
                break

            images = to_var(images)
            captions = to_var(captions)
            lengths = to_var(torch.LongTensor(lengths))

            # Forward, Backward and Optimize
            vae_optim.zero_grad()
            img_out, img_mu, img_logv, img_z, txt_out, txt_mu, txt_logv, txt_z = \
                coupled_vae(images, captions, lengths, train_swapped)

            # Normalize the image ELBO by the number of pixels in the batch.
            img_rc_loss = img_vae_loss(img_out, images, img_mu, img_logv) \
                / (args.batch_size * args.crop_size**2)

            # Logistic KL annealing: weight ramps with global `step`.
            NLL_loss, KL_loss, KL_weight = seq_vae_loss(
                txt_out, captions, lengths, txt_mu, txt_logv,
                "logistic", step, 0.0025, 2500)
            txt_rc_loss = (NLL_loss + KL_weight * KL_loss) / torch.sum(lengths).float()

            txt_losses.update(txt_rc_loss.data[0], args.batch_size)
            img_losses.update(img_rc_loss.data[0], args.batch_size)

            loss = img_rc_loss + txt_rc_loss
            loss.backward()
            vae_optim.step()
            step += 1

            if i % args.image_save_interval == 0:
                # BUG FIX: use floor division so the subdirectory name is an
                # integer ("0", "1", ...) instead of a float ("0.0") on Python 3.
                subdir_path = os.path.join(result_path,
                                           str(i // args.image_save_interval))
                if not os.path.exists(subdir_path):
                    os.makedirs(subdir_path)

                for im_idx in range(3):
                    # Undo the [-1, 1] scaling before writing the JPEGs.
                    im_or = (images[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255
                    im = (img_out[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255

                    filename_prefix = os.path.join(subdir_path, str(im_idx))
                    scipy.misc.imsave(filename_prefix + '_original.A.jpg', im_or)
                    scipy.misc.imsave(filename_prefix + '.A.jpg', im)

                    # Dump original and greedily-decoded captions side by side.
                    txt_or = " ".join([vocab.idx2word[c]
                                       for c in captions[im_idx].cpu().data.numpy()])
                    _, generated = torch.topk(txt_out[im_idx], 1)
                    txt = " ".join([vocab.idx2word[c]
                                    for c in generated[:, 0].cpu().data.numpy()])

                    with open(filename_prefix + "_captions.txt", "w") as text_file:
                        text_file.write("Epoch %d\n" % epoch)
                        text_file.write("Original: %s\n" % txt_or)
                        text_file.write("Generated: %s" % txt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                img_l=img_losses.avg,
                txt_l=txt_losses.avg,
                cm_l=cm_losses.avg,
            )
            bar.next()
        bar.finish()

        with open(os.path.join(result_path, "losses.csv"), "a") as text_file:
            text_file.write("{}, {}, {}, {}\n".format(
                epoch, img_losses.avg, txt_losses.avg, cm_losses.avg))

        # Save the models.
        # BUG FIX: the original wrote `'coupled_vae' % (epoch + 1)` — a
        # %-format on a string with no placeholder, which raises TypeError
        # the first time a checkpoint is saved.
        print('\n')
        print('Saving the models in {}...'.format(model_path))
        torch.save(coupled_vae.state_dict(),
                   os.path.join(model_path, 'coupled_vae-%d' % (epoch + 1)) + ".pkl")
def main():
    """Train only the image autoencoder (encoder_Img / decoder_Img).

    Text and cross-modal paths are stubbed out: the text/CM loss meters are
    updated with 0 each batch so the shared progress-bar format still works.
    Checkpoints are written per epoch under a timestamped model_path.
    """
    print("Initializing...")
    # global args
    args = parser.parse_args()

    now = datetime.datetime.now()
    current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'

    # Width of the shared (cross-modal) slice of the hidden representation.
    mask = args.common_emb_size
    assert mask <= args.hidden_size

    # args.cuda arrives as the string 'true'/'false'.
    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    result_path = args.result_path
    model_path = args.model_path + current_date + "/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        print("Creating model path on", model_path)
        os.makedirs(model_path)

    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings.  When only a directory is given, point at the GloVe
    # file matching the requested embedding size.
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # Project-local Embeddings wrapper; weights copied from GloVe and frozen.
    glove_emb = Embeddings(emb_size, len(vocab.word2idx), vocab.word2idx["<pad>"])
    glove_emb.word_lut.weight.data.copy_(emb)
    glove_emb.word_lut.weight.requires_grad = False

    # glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)

    # Freeze weighs
    # if args.fixed_embeddings == "true":
    #     glove_emb.weight.requires_grad = False

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                            transform, args.batch_size,
                            shuffle=True, num_workers=args.num_workers)

    print("Setting up the Networks...")
    encoder_Img = ImageEncoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)
    decoder_Img = ImageDecoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)

    if cuda:
        encoder_Img = encoder_Img.cuda()
        decoder_Img = decoder_Img.cuda()

    # Losses and Optimizers
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    # txt_criterion = nn.MSELoss(size_average=True)
    if cuda:
        img_criterion = img_criterion.cuda()
    # txt_criterion = nn.CrossEntropyLoss()

    # gen_params = chain(generator_A.parameters(), generator_B.parameters())
    print("Setting up the Optimizers...")
    # img_params = chain(decoder_Img.parameters(), encoder_Img.parameters())
    img_params = list(decoder_Img.parameters()) + list(encoder_Img.parameters())

    # ATTENTION: Check betas and weight decay
    # ATTENTION: Check why valid_params fails on image networks with out of memory error
    # NOTE(review): lr is hard-coded to 0.001 here, ignoring args.learning_rate
    # — looks deliberate for this experiment, but confirm.
    img_optim = optim.Adam(img_params, lr=0.001)  #,betas=(0.5, 0.999), weight_decay=0.00001)
    # img_enc_optim = optim.Adam(encoder_Img.parameters(), lr=args.learning_rate)#betas=(0.5, 0.999), weight_decay=0.00001)
    # img_dec_optim = optim.Adam(decoder_Img.parameters(), lr=args.learning_rate)#betas=(0.5,0.999), weight_decay=0.00001)

    train_images = False  # Reverse 2

    for epoch in range(args.num_epochs):
        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        img_losses = AverageMeter()
        txt_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()
        bar = Bar('Training Net', max=len(data_loader))

        # Set training mode
        encoder_Img.train()
        decoder_Img.train()

        train_images = True
        for i, (images, captions, lengths) in enumerate(data_loader):
            # ATTENTION REMOVE — hard-coded early stop left in by the author.
            if i == 6450:
                break

            # Set mini-batch dataset
            images = to_var(images)
            captions = to_var(captions)
            # target = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            # captions, lengths = pad_sequences(captions, lengths)
            # images = torch.FloatTensor(images)

            # Reshape captions to (seq_len, batch, 1) for the seq2seq modules.
            captions = captions.transpose(0, 1).unsqueeze(2)
            lengths = torch.LongTensor(lengths)  # print(captions.size())

            # Forward, Backward and Optimize
            # img_optim.zero_grad()
            # img_dec_optim.zero_grad()
            # img_enc_optim.zero_grad()
            encoder_Img.zero_grad()
            decoder_Img.zero_grad()

            # txt_params.zero_grad()
            # txt_dec_optim.zero_grad()
            # txt_enc_optim.zero_grad()

            # Image Auto_Encoder Forward
            img_encoder_outputs, Iz = encoder_Img(images)
            IzI = decoder_Img(img_encoder_outputs)
            img_rc_loss = img_criterion(IzI, images)

            # Text Auto Encoder Forward
            # target = target[:-1] # exclude last target from inputs

            # Text/CM paths are disabled: log zeros so the meters still exist.
            img_loss = img_rc_loss
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            txt_losses.update(0, args.batch_size)
            cm_losses.update(0, args.batch_size)

            # Image Network Training and Backpropagation
            img_loss.backward()
            img_optim.step()

            if i % args.image_save_interval == 0:
                # NOTE(review): under Python 3 this `/` gives float dir names
                # like "0.0" — presumably `//` was intended; confirm.
                subdir_path = os.path.join(result_path,
                                           str(i / args.image_save_interval))
                if os.path.exists(subdir_path):
                    pass
                else:
                    os.makedirs(subdir_path)
                for im_idx in range(3):
                    # Undo the [-1, 1] scaling before writing the JPEGs.
                    im_or = (images[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255
                    im = (IzI[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255
                    filename_prefix = os.path.join(subdir_path, str(im_idx))
                    scipy.misc.imsave(filename_prefix + '_original.A.jpg', im_or)
                    scipy.misc.imsave(filename_prefix + '.A.jpg', im)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                img_l=img_losses.avg,
                txt_l=txt_losses.avg,
                cm_l=cm_losses.avg,
            )
            bar.next()
        bar.finish()

        # Save the models
        print('\n')
        print('Saving the models in {}...'.format(model_path))
        torch.save(decoder_Img.state_dict(),
                   os.path.join(model_path, 'decoder-img-%d-' % (epoch + 1)) + current_date + ".pkl")
        torch.save(encoder_Img.state_dict(),
                   os.path.join(model_path, 'encoder-img-%d-' % (epoch + 1)) + current_date + ".pkl")
def main():
    """Dispatch to the trainer selected by ``--method`` and run training.

    Picks a trainer class, builds the vocabulary/embeddings/data loaders for
    either COCO or CUB, then runs the epoch loop: the trainer does the actual
    forward/backward, this function only drives iteration, progress display
    and (optionally) a validation pass per epoch.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "NONE":
        args.comment = args.method

    validate = args.validate == "true"

    # Map the requested method to its trainer class.
    if args.method == "coupled_vae_gan":
        trainer = coupled_vae_gan_trainer.coupled_vae_gan_trainer
    elif args.method == "coupled_vae":
        trainer = coupled_vae_trainer.coupled_vae_trainer
    elif args.method == "wgan":
        trainer = wgan_trainer.wgan_trainer
    elif args.method == "seq_wgan":
        trainer = seq_wgan_trainer.wgan_trainer
    elif args.method == "skip_thoughts":
        trainer = skipthoughts_vae_gan_trainer.coupled_vae_gan_trainer
    else:
        assert False, "Invalid method"

    assert args.text_criterion in ("MSE", "Cosine", "Hinge", "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'
    assert args.common_emb_ratio <= 1.0 and args.common_emb_ratio >= 0
    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">
    # Normalize to [-1, 1] (this pipeline's decoders output tanh-scaled images).
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5))
    ])
    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">
    # Non-COCO runs use the CUB vocabulary.
    if args.dataset != "coco":
        args.vocab_path = "./data/cub_vocab.pkl"

    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # When only a directory is given, point at the GloVe file for this size.
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    use_glove = args.use_glove == "true"
    if use_glove:
        emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)
        word_emb = nn.Embedding(emb.size(0), emb.size(1))
        word_emb.weight = nn.Parameter(emb)
    else:
        word_emb = nn.Embedding(len(vocab), emb_size)

    # Freeze the embedding weights when requested.
    # BUG FIX: this previously set requires_grad = True, leaving the
    # embeddings trainable exactly when the user asked for them to be fixed
    # (the sibling main() variants in this file set False here).
    if args.fixed_embeddings == "true":
        word_emb.weight.requires_grad = False
    # </editor-fold>

    # <editor-fold desc="Data-Loaders">
    print("Building Data Loader For Test Set...")
    if args.dataset == 'coco':
        data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                                 transform, args.batch_size,
                                 shuffle=True, num_workers=args.num_workers)

        print("Building Data Loader For Validation Set...")
        val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                                transform, args.batch_size,
                                shuffle=True, num_workers=args.num_workers)
    else:
        data_path = "data/cub.h5"
        dataset = Text2ImageDataset(data_path, split=0, vocab=vocab,
                                    transform=transform)
        data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                 shuffle=True, num_workers=args.num_workers,
                                 collate_fn=collate_fn)
        dataset_val = Text2ImageDataset(data_path, split=1, vocab=vocab,
                                        transform=transform)
        val_loader = DataLoader(dataset_val, batch_size=args.batch_size,
                                shuffle=True, num_workers=args.num_workers,
                                collate_fn=collate_fn)
    # </editor-fold>

    # BUG FIX: a stray statement referencing ``self.networks["coupled_vae"]``
    # (copy-pasted from a trainer method) used to sit here; ``self`` and
    # ``txt2txt_out`` are undefined in this function, so it raised NameError.
    # It has been removed.

    # <editor-fold desc="Network Initialization">
    print("Setting up the trainer...")
    model_trainer = trainer(args, word_emb, vocab)
    # </editor-fold>

    for epoch in range(args.num_epochs):
        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        end = time.time()
        bar = Bar(args.method if args.comment == "NONE"
                  else args.method + "/" + args.comment,
                  max=len(data_loader))

        model_trainer.set_train_models()
        model_trainer.create_losses_meter(model_trainer.losses)

        for i, (images, captions, lengths) in enumerate(data_loader):
            # Skip this epoch entirely if a saved checkpoint covers it.
            if model_trainer.load_models(epoch):
                break

            # Skip the last (possibly short) batch.
            if i == len(data_loader) - 1:
                break

            images = to_var(images)
            captions = to_var(captions)
            lengths = to_var(torch.LongTensor(lengths))

            # The last flag asks the trainer to dump sample images this step.
            model_trainer.forward(epoch, images, captions, lengths,
                                  not i % args.image_save_interval)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if not model_trainer.iteration % args.log_step:
                # plot progress
                bar.suffix = bcolors.HEADER
                bar.suffix += '({batch}/{size}) Iter: {bt:} | Time: {total:}-{eta:}\n'.format(
                    batch=i,
                    size=len(data_loader),
                    bt=model_trainer.iteration,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                )
                bar.suffix += bcolors.ENDC

                # Append every tracked loss, five per line.
                cnt = 0
                for l_name, l_value in sorted(model_trainer.losses.items(),
                                              key=lambda x: x[0]):
                    cnt += 1
                    bar.suffix += ' | {name}: {val:.3f}'.format(
                        name=l_name,
                        val=l_value.avg,
                    )
                    if not cnt % 5:
                        bar.suffix += "\n"

            bar.next()
        bar.finish()

        if validate:
            print('EPOCH ::: VALIDATION ::: ' + str(epoch + 1))
            batch_time = AverageMeter()
            end = time.time()
            barName = args.method if args.comment == "NONE" \
                else args.method + "/" + args.comment
            barName = "VAL:" + barName
            bar = Bar(barName, max=len(val_loader))

            model_trainer.set_eval_models()
            model_trainer.create_metrics_meter(model_trainer.metrics)

            for i, (images, captions, lengths) in enumerate(val_loader):
                # Skip the last (possibly short) batch.
                if i == len(val_loader) - 1:
                    break

                images = to_var(images)
                captions = to_var(captions[:, 1:])  # drop the leading token

                # The last flag asks for sample dumps on the first batch only.
                model_trainer.evaluate(epoch, images, captions, lengths, i == 0)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = bcolors.HEADER
                bar.suffix += '({batch}/{size}) Iter: {bt:} | Time: {total:}-{eta:}\n'.format(
                    batch=i,
                    size=len(val_loader),
                    bt=model_trainer.iteration,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                )
                bar.suffix += bcolors.ENDC

                cnt = 0
                for l_name, l_value in sorted(model_trainer.metrics.items(),
                                              key=lambda x: x[0]):
                    cnt += 1
                    bar.suffix += ' | {name}: {val:.3f}'.format(
                        name=l_name,
                        val=l_value.avg,
                    )
                    if not cnt % 5:
                        bar.suffix += "\n"

                bar.next()
            bar.finish()

    # Final checkpoint; epoch index -1 marks "final".
    model_trainer.save_models(-1)
def main():
    """Drive training through a pre-selected ``trainer`` object.

    Builds the vocabulary, GloVe embeddings and COCO data loaders, then runs
    the epoch loop; the module-level ``trainer`` callable does the actual
    optimization step via ``model_trainer.train(...)``.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "test":
        print("WARNING: name is test!!!\n\n")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge", "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'
    assert args.common_emb_ratio <= 1.0 and args.common_emb_ratio >= 0
    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # When only a directory is given, point at the GloVe file for this size.
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)
    glove_emb = nn.Embedding(emb.size(0), emb.size(1))

    # Freeze the embedding weights when requested.
    if args.fixed_embeddings == "true":
        glove_emb.weight.requires_grad = False
    # </editor-fold>

    # <editor-fold desc="Data-Loaders">
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                            transform, args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
    # </editor-fold>

    # <editor-fold desc="Network Initialization">
    print("Setting up the trainer...")
    model_trainer = trainer(args, glove_emb, vocab)
    # </editor-fold>

    for epoch in range(args.num_epochs):
        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        # BUG FIX: txt_losses and img_losses are updated inside the batch
        # loop but were never created (only cm_losses was), causing a
        # NameError on the first batch.
        img_losses = AverageMeter()
        txt_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()
        bar = Bar('Training Net', max=len(data_loader))

        for i, (images, captions, lengths) in enumerate(data_loader):
            # Skip the last (possibly short) batch.
            if i == len(data_loader) - 1:
                break

            images = to_var(images)
            captions = to_var(captions)
            lengths = to_var(torch.LongTensor(lengths))

            # The last flag asks the trainer to dump sample images this step.
            img_rc_loss, txt_rc_loss = model_trainer.train(
                images, captions, lengths, not i % args.image_save_interval)

            txt_losses.update(txt_rc_loss.data[0], args.batch_size)
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            # cm_losses.update(cm_loss.data[0], args.batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            # BUG FIX: this was assigned to a throwaway local `bar_suffix`,
            # so the progress bar never showed any of it.
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
            )
            bar.next()
        bar.finish()

        model_trainer.save_losses(epoch, img_losses.avg, txt_losses.avg)
        model_trainer.save_models(epoch)
def main():
    """Jointly train text and image autoencoders with a cross-modal loss.

    Each epoch alternates which pipeline is updated (images vs. text), ties
    the first ``mask`` hidden dimensions of the two encoders together via
    ``cm_criterion``, pushes apart negative (permuted) pairs, checkpoints all
    four networks, optionally validates, and logs to TensorBoard.
    """
    # global args
    args = parser.parse_args()

    writer = SummaryWriter()

    # <editor-fold desc="Initialization">
    now = datetime.datetime.now()
    current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge", "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'

    # Width of the shared (cross-modal) slice of the hidden representation.
    mask = args.common_emb_size
    assert mask <= args.hidden_size

    # args.cuda arrives as the string 'true'/'false'.
    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    model_path = args.model_path + current_date + args.comment + "/"
    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    #</editor-fold>

    # <editor-fold desc="Image Preprocessing">
    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    #</editor-fold>

    # <editor-fold desc="Creating Embeddings">
    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings.  When only a directory is given, point at the GloVe
    # file matching the requested embedding size.
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # Project-local Embeddings wrapper; weights copied from GloVe and frozen.
    glove_emb = Embeddings(emb_size, len(vocab.word2idx), vocab.word2idx["<pad>"])
    glove_emb.word_lut.weight.data.copy_(emb)
    glove_emb.word_lut.weight.requires_grad = False

    # glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)

    # Freeze weighs
    # if args.fixed_embeddings == "true":
    #     glove_emb.weight.requires_grad = False

    # </editor-fold>

    # <editor-fold desc="Data-Loaders">
    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                            transform, args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
    # </editor-fold>

    # <editor-fold desc="Network Initialization">
    print("Setting up the Networks...")
    encoder_Txt = TextEncoder(glove_emb, num_layers=1, bidirectional=False,
                              hidden_size=args.hidden_size)
    decoder_Txt = TextDecoder(glove_emb, num_layers=1, bidirectional=False,
                              hidden_size=args.hidden_size)
    # decoder_Txt = TextDecoder(encoder_Txt, glove_emb)
    # decoder_Txt = DecoderRNN(glove_emb, hidden_size=args.hidden_size)

    encoder_Img = ImageEncoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)
    decoder_Img = ImageDecoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)

    if cuda:
        encoder_Txt = encoder_Txt.cuda()
        decoder_Img = decoder_Img.cuda()
        encoder_Img = encoder_Img.cuda()
        decoder_Txt = decoder_Txt.cuda()
    # </editor-fold>

    # <editor-fold desc="Losses">
    # Losses and Optimizers
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    # txt_criterion = nn.MSELoss(size_average=True)
    if args.text_criterion == 'MSE':
        txt_criterion = nn.MSELoss()
    elif args.text_criterion == "Cosine":
        txt_criterion = nn.CosineEmbeddingLoss(size_average=False)
    else:
        txt_criterion = nn.HingeEmbeddingLoss(size_average=False)

    if args.cm_criterion == 'MSE':
        cm_criterion = nn.MSELoss()
    elif args.cm_criterion == "Cosine":
        cm_criterion = nn.CosineEmbeddingLoss()
    else:
        cm_criterion = nn.HingeEmbeddingLoss()

    if cuda:
        img_criterion = img_criterion.cuda()
        txt_criterion = txt_criterion.cuda()
        cm_criterion = cm_criterion.cuda()
    # txt_criterion = nn.CrossEntropyLoss()
    # </editor-fold>

    # <editor-fold desc="Optimizers">
    # gen_params = chain(generator_A.parameters(), generator_B.parameters())
    print("Setting up the Optimizers...")
    # img_params = chain(decoder_Img.parameters(), encoder_Img.parameters())
    # txt_params = chain(decoder_Txt.decoder.parameters(), encoder_Txt.encoder.parameters())
    # img_params = list(decoder_Img.parameters()) + list(encoder_Img.parameters())
    # txt_params = list(decoder_Txt.decoder.parameters()) + list(encoder_Txt.encoder.parameters())

    # ATTENTION: Check betas and weight decay
    # ATTENTION: Check why valid_params fails on image networks with out of memory error
    # img_optim = optim.Adam(img_params, lr=0.0001, betas=(0.5, 0.999), weight_decay=0.00001)
    # txt_optim = optim.Adam(valid_params(txt_params), lr=0.0001,betas=(0.5, 0.999), weight_decay=0.00001)
    img_enc_optim = optim.Adam(
        encoder_Img.parameters(),
        lr=args.learning_rate)  #betas=(0.5, 0.999), weight_decay=0.00001)
    img_dec_optim = optim.Adam(
        decoder_Img.parameters(),
        lr=args.learning_rate)  #betas=(0.5,0.999), weight_decay=0.00001)
    txt_enc_optim = optim.Adam(
        valid_params(encoder_Txt.encoder.parameters()),
        lr=args.learning_rate)  #betas=(0.5,0.999), weight_decay=0.00001)
    txt_dec_optim = optim.Adam(
        valid_params(decoder_Txt.decoder.parameters()),
        lr=args.learning_rate)  #betas=(0.5,0.999), weight_decay=0.00001)
    # </editor-fold desc="Optimizers">

    train_images = False  # Reverse 2

    for epoch in range(args.num_epochs):
        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()
        bar = Bar('Training Net', max=len(data_loader))

        # Set training mode
        encoder_Img.train()
        decoder_Img.train()
        encoder_Txt.encoder.train()
        decoder_Txt.decoder.train()

        # Weight for negative cross-modal pairs; decays linearly to 0 by epoch 10.
        neg_rate = max(0, 2 * (10 - epoch) / 10)

        # Alternate which pipeline gets updated each epoch.
        train_images = not train_images

        for i, (images, captions, lengths) in enumerate(data_loader):
            # ATTENTION REMOVE — skips the last (possibly short) batch.
            if i == len(data_loader) - 1:
                break

            # <editor-fold desc = "Training Parameters Initiliazation"?
            # Set mini-batch dataset
            images = to_var(images)
            captions = to_var(captions)
            # target = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            # captions, lengths = pad_sequences(captions, lengths)
            # images = torch.FloatTensor(images)

            # Reshape captions to (seq_len, batch, 1) for the seq2seq modules.
            captions = captions.transpose(0, 1).unsqueeze(2)
            lengths = torch.LongTensor(lengths)  # print(captions.size())

            # Forward, Backward and Optimize
            # img_optim.zero_grad()
            img_dec_optim.zero_grad()
            img_enc_optim.zero_grad()
            # encoder_Img.zero_grad()
            # decoder_Img.zero_grad()

            # txt_params.zero_grad()
            txt_dec_optim.zero_grad()
            txt_enc_optim.zero_grad()
            # encoder_Txt.encoder.zero_grad()
            # decoder_Txt.decoder.zero_grad()
            # </editor-fold desc = "Training Parameters Initiliazation"?

            # <editor-fold desc = "Image AE"?
            # Image Auto_Encoder Forward
            img_encoder_outputs, Iz = encoder_Img(images)
            IzI = decoder_Img(img_encoder_outputs)
            img_rc_loss = img_criterion(IzI, images)
            # </editor-fold desc = "Image AE"?

            # <editor-fold desc = "Seq2Seq AE"?
            # Text Auto Encoder Forward
            # target = target[:-1] # exclude last target from inputs
            captions = captions[:-1, :, :]
            lengths = lengths - 1
            dec_state = None

            encoder_outputs, memory_bank = encoder_Txt(captions, lengths)

            enc_state = \
                decoder_Txt.decoder.init_decoder_state(captions, memory_bank, encoder_outputs)
            decoder_outputs, dec_state, attns = \
                decoder_Txt.decoder(captions, memory_bank,
                                    enc_state if dec_state is None else dec_state,
                                    memory_lengths=lengths)

            Tz = encoder_outputs
            TzT = decoder_outputs
            # </editor-fold desc = "Seq2Seq AE"?

            # <editor-fold desc = "Loss accumulation"?
            if args.text_criterion == 'MSE':
                txt_rc_loss = txt_criterion(TzT, glove_emb(captions))
            else:
                # NOTE(review): `TzT.size(0,1)` passes two dims to size() —
                # this looks wrong (size() takes one dim); presumably
                # TzT.size(1) (the batch dim) was intended.  Confirm before
                # using a non-MSE text criterion.
                txt_rc_loss = txt_criterion(TzT, glove_emb(captions),\
                                            Variable(torch.ones(TzT.size(0,1))).cuda())

            # # for x,y,l in zip(TzT.transpose(0,1),glove_emb(captions).transpose(0,1),lengths):
            #     if args.criterion == 'MSE':
            #         # ATTENTION dunno what's the right one
            #         txt_rc_loss += txt_criterion(x,y)
            #     else:
            #         # ATTENTION Fails on last batch
            #         txt_rc_loss += txt_criterion(x, y, Variable(torch.ones(x.size(0))).cuda())/l
            # # txt_rc_loss /= captions.size(1)

            # Computes Cross-Modal Loss on the first `mask` hidden dims only.
            Tz = Tz[0]

            txt = Tz.narrow(1, 0, mask)
            im = Iz.narrow(1, 0, mask)

            if args.cm_criterion == 'MSE':
                # cm_loss = cm_criterion(Tz.narrow(1,0,mask), Iz.narrow(1,0,mask))
                cm_loss = mse_loss(txt, im)
            else:
                cm_loss = cm_criterion(txt, im, \
                                       Variable(torch.ones(im.size(0)).cuda()))

            # K - Negative Samples: push apart randomly permuted (mismatched)
            # image/text pairs.
            k = args.negative_samples
            for _ in range(k):
                if cuda:
                    perm = torch.randperm(args.batch_size).cuda()
                else:
                    perm = torch.randperm(args.batch_size)

                # if args.criterion == 'MSE':
                #     cm_loss -= mse_loss(txt, im[perm])/k
                # else:
                #     cm_loss -= cm_criterion(txt, im[perm], \
                #                             Variable(torch.ones(Tz.narrow(1,0,mask).size(0)).cuda()))/k

                # sim = (F.cosine_similarity(txt,txt[perm]) - 0.5)/2
                if args.cm_criterion == 'MSE':
                    # Weight each negative by how dissimilar the two texts are.
                    sim = (F.cosine_similarity(txt, txt[perm]) - 1) / (2 * k)
                    # cm_loss = cm_criterion(Tz.narrow(1,0,mask), Iz.narrow(1,0,mask))
                    cm_loss += mse_loss(txt, im[perm], sim)
                else:
                    cm_loss += neg_rate * cm_criterion(txt, im[perm], \
                                                       Variable(-1*torch.ones(txt.size(0)).cuda()))/k

            # cm_loss = Variable(torch.max(torch.FloatTensor([-0.100]).cuda(), cm_loss.data))

            # Computes the loss to be back-propagated
            img_loss = img_rc_loss * (
                1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
            txt_loss = txt_rc_loss * (
                1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
            # txt_loss = txt_rc_loss + 0.1 * cm_loss
            # img_loss = img_rc_loss + cm_loss

            txt_losses.update(txt_rc_loss.data[0], args.batch_size)
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            cm_losses.update(cm_loss.data[0], args.batch_size)
            # </editor-fold desc = "Loss accumulation"?

            # <editor-fold desc = "Back Propagation">
            # Half of the times we update one pipeline the others the other one
            if train_images:
                # Image Network Training and Backpropagation
                img_loss.backward()
                # img_optim.step()
                img_enc_optim.step()
                img_dec_optim.step()
            else:
                # Text Nextwork Training & Back Propagation
                txt_loss.backward()
                # txt_optim.step()
                txt_enc_optim.step()
                txt_dec_optim.step()
            # </editor-fold desc = "Back Propagation">

            # <editor-fold desc = "Logging">
            if i % args.image_save_interval == 0:
                # NOTE(review): under Python 3 this `/` gives float dir names
                # like "0.0" — presumably `//` was intended; confirm.
                subdir_path = os.path.join(result_path,
                                           str(i / args.image_save_interval))
                if os.path.exists(subdir_path):
                    pass
                else:
                    os.makedirs(subdir_path)
                for im_idx in range(3):
                    # Undo the [-1, 1] scaling before writing the JPEGs.
                    im_or = (images[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255
                    im = (IzI[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255
                    filename_prefix = os.path.join(subdir_path, str(im_idx))
                    scipy.misc.imsave(filename_prefix + '_original.A.jpg', im_or)
                    scipy.misc.imsave(filename_prefix + '.A.jpg', im)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                img_l=img_losses.avg,
                txt_l=txt_losses.avg,
                cm_l=cm_losses.avg,
            )
            bar.next()
            # </editor-fold desc = "Logging">
        bar.finish()

        # <editor-fold desc = "Saving the models"?
        # Save the models
        print('\n')
        print('Saving the models in {}...'.format(model_path))
        torch.save(decoder_Img.state_dict(),
                   os.path.join(model_path, 'decoder-img-%d-' % (epoch + 1)) + current_date + ".pkl")
        torch.save(encoder_Img.state_dict(),
                   os.path.join(model_path, 'encoder-img-%d-' % (epoch + 1)) + current_date + ".pkl")
        torch.save(decoder_Txt.state_dict(),
                   os.path.join(model_path, 'decoder-txt-%d-' % (epoch + 1)) + current_date + ".pkl")
        torch.save(encoder_Txt.state_dict(),
                   os.path.join(model_path, 'encoder-txt-%d-' % (epoch + 1)) + current_date + ".pkl")
        # </editor-fold desc = "Saving the models"?

        # <editor-fold desc = "Validation">
        if args.validate == "true":
            print("Train Set")
            validate(encoder_Img, encoder_Txt, data_loader, mask, 10)
            print("Test Set")
            validate(encoder_Img, encoder_Txt, val_loader, mask, 10)
        # </editor-fold desc = "Validation">

        # Per-epoch TensorBoard scalars.
        writer.add_scalars(
            'data/scalar_group', {
                'Image_RC': img_losses.avg,
                'Text_RC': txt_losses.avg,
                'CM_loss': cm_losses.avg
            }, epoch)
def main():
    """Train paired image/text auto-encoders with a cross-modal alignment loss.

    Pipeline: parse CLI args -> build transforms, vocab, GloVe embeddings and
    data loaders -> build text encoder/decoder (seq2seq) and image
    encoder/decoder -> train, alternating image- and text-pipeline updates
    each mini-batch, saving checkpoints and sample reconstructions per epoch.

    Relies on module-level globals: `parser`, `get_loader`,
    `load_glove_embeddings`, `TextEncoder`, `TextDecoder`, `ImageEncoder`,
    `ImageDecoder`, `valid_params`, `AverageMeter`, `Bar`, `to_var`,
    `mse_loss`, `masked_cross_entropy`, `validate`, and the usual
    torch/torchvision imports.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    # Timestamp used to tag checkpoint directories and filenames.
    now = datetime.datetime.now()
    current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge", "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'

    # `mask` = number of leading hidden units shared across modalities.
    mask = int(args.common_emb_percentage * args.hidden_size)
    assert mask <= args.hidden_size

    # CLI passes cuda as the string "true"/"false"; normalize to bool.
    cuda = args.cuda == 'true'

    # BUGFIX: polarity was inverted. If a model name is given we must keep
    # loading its checkpoints; if "NONE" we train from scratch into a
    # date-stamped directory. (Matches the other variants of this script.)
    if args.load_model == "NONE":
        keep_loading = False
        model_path = args.model_path + current_date + "/"
    else:
        keep_loading = True
        model_path = args.model_path + args.load_model + "/"

    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">
    # Normalization constants per https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Resolve the GloVe file from a directory path if one was given.
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # NOTE(review): the pretrained weights `emb` are loaded but never copied
    # into `glove_emb` here (the copy is commented out in other variants);
    # the embedding trains from random init. Confirm this is intended.
    glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # </editor-fold>

    # <editor-fold desc="Data-Loaders">
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                            transform, args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
    # </editor-fold>

    # <editor-fold desc="Network Initialization">
    print("Setting up the Networks...")
    encoder_Txt = TextEncoder(glove_emb, num_layers=1, bidirectional=False,
                              hidden_size=args.hidden_size)
    decoder_Txt = TextDecoder(glove_emb, len(vocab), num_layers=1,
                              bidirectional=False, hidden_size=args.hidden_size)

    encoder_Img = ImageEncoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)
    decoder_Img = ImageDecoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)

    if cuda:
        encoder_Txt = encoder_Txt.cuda()
        decoder_Img = decoder_Img.cuda()
        encoder_Img = encoder_Img.cuda()
        decoder_Txt = decoder_Txt.cuda()
    # </editor-fold>

    # <editor-fold desc="Losses">
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    if args.text_criterion == 'MSE':
        txt_criterion = nn.MSELoss()
    elif args.text_criterion == "Cosine":
        txt_criterion = nn.CosineEmbeddingLoss(size_average=False)
    elif args.text_criterion == "NLLLoss":
        txt_criterion = nn.NLLLoss()
    else:
        txt_criterion = nn.HingeEmbeddingLoss(size_average=False)

    if args.cm_criterion == 'MSE':
        cm_criterion = nn.MSELoss()
    elif args.cm_criterion == "Cosine":
        cm_criterion = nn.CosineEmbeddingLoss()
    else:
        cm_criterion = nn.HingeEmbeddingLoss()

    if cuda:
        img_criterion = img_criterion.cuda()
        txt_criterion = txt_criterion.cuda()
        cm_criterion = cm_criterion.cuda()
    # </editor-fold>

    # <editor-fold desc="Optimizers">
    print("Setting up the Optimizers...")
    # Separate optimizers per sub-network so each pipeline can be stepped
    # independently. `valid_params` presumably filters out frozen parameters
    # (e.g. fixed embeddings) -- TODO confirm against its definition.
    img_enc_optim = optim.Adam(encoder_Img.parameters(), lr=args.learning_rate)
    img_dec_optim = optim.Adam(decoder_Img.parameters(), lr=args.learning_rate)
    txt_enc_optim = optim.Adam(valid_params(encoder_Txt.parameters()), lr=args.learning_rate)
    txt_dec_optim = optim.Adam(valid_params(decoder_Txt.parameters()), lr=args.learning_rate)
    # </editor-fold>

    train_images = False  # toggled each batch: alternate image/text updates

    for epoch in range(args.num_epochs):
        # <editor-fold desc="Epoch Initialization">
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        if keep_loading:
            # Try to resume this epoch from saved checkpoints; fall back to
            # training once no checkpoint exists for the epoch.
            suffix = "-" + str(epoch) + "-" + args.load_model + ".pkl"
            try:
                encoder_Img.load_state_dict(torch.load(os.path.join(args.model_path, 'encoder-img' + suffix)))
                encoder_Txt.load_state_dict(torch.load(os.path.join(args.model_path, 'encoder-txt' + suffix)))
                decoder_Img.load_state_dict(torch.load(os.path.join(args.model_path, 'decoder-img' + suffix)))
                decoder_Txt.load_state_dict(torch.load(os.path.join(args.model_path, 'decoder-txt' + suffix)))
            except FileNotFoundError:
                print("Didn't find any models switching to training")
                keep_loading = False

        if not keep_loading:
            encoder_Img.train()
            decoder_Img.train()
            encoder_Txt.train()
            decoder_Txt.train()
        # </editor-fold>

        train_images = not train_images
        for i, (images, captions, lengths) in enumerate(data_loader):
            # Skip the final (possibly smaller) batch: the code assumes a
            # fixed args.batch_size throughout.
            if i == len(data_loader) - 1:
                break

            # <editor-fold desc="Training Parameters Initialization">
            images = to_var(images)
            captions = to_var(captions)
            # (seq_len, batch, 1) layout expected by the text encoder
            captions = captions.transpose(0, 1).unsqueeze(2)
            lengths = to_var(torch.LongTensor(lengths))

            img_dec_optim.zero_grad()
            img_enc_optim.zero_grad()
            txt_dec_optim.zero_grad()
            txt_enc_optim.zero_grad()
            # </editor-fold>

            # <editor-fold desc="Image AE">
            mu, logvar = encoder_Img(images)
            # NOTE(review): logvar is used directly as the image code and mu
            # is decoded -- no reparametrization. Confirm intended.
            Iz = logvar
            IzI = decoder_Img(mu)
            img_rc_loss = img_criterion(IzI, images)
            # </editor-fold>

            # <editor-fold desc="Seq2Seq AE">
            teacher_forcing_ratio = 0.5
            encoder_hidden = encoder_Txt.initHidden(args.batch_size)
            input_length = captions.size(0)
            target_length = captions.size(0)

            if cuda:
                encoder_outputs = Variable(torch.zeros(input_length, args.batch_size, args.hidden_size).cuda())
                decoder_outputs = Variable(torch.zeros(input_length, args.batch_size, len(vocab)).cuda())
            else:
                encoder_outputs = Variable(torch.zeros(input_length, args.batch_size, args.hidden_size))
                decoder_outputs = Variable(torch.zeros(input_length, args.batch_size, len(vocab)))

            txt_rc_loss = 0
            for ei in range(input_length):
                encoder_output, encoder_hidden = encoder_Txt(captions[ei, :], encoder_hidden)
                encoder_outputs[ei] = encoder_output

            decoder_input = Variable(torch.LongTensor([vocab.word2idx['<start>']])).cuda()\
                .repeat(args.batch_size, 1)
            decoder_hidden = encoder_hidden

            # Teacher forcing is currently always on.
            use_teacher_forcing = True  # if np.random.random() < teacher_forcing_ratio else False

            if use_teacher_forcing:
                # Teacher forcing: feed the ground-truth token as next input.
                for di in range(target_length - 1):
                    decoder_output, decoder_hidden = decoder_Txt(decoder_input, decoder_hidden)
                    decoder_outputs[di] = decoder_output
                    decoder_input = captions[di + 1]
            else:
                # Free-running: feed the model's own prediction back in.
                # BUGFIX: was `decoder_outputs, decoder_hidden = ...`, which
                # clobbered the output buffer and left `decoder_output`
                # undefined (NameError on the next line).
                for di in range(target_length - 1):
                    decoder_output, decoder_hidden = decoder_Txt(decoder_input, decoder_hidden)
                    topv, topi = decoder_output.topk(1)
                    decoder_input = topi.squeeze().detach()  # detach from history as input
                    txt_rc_loss += txt_criterion(decoder_output, captions[di])

            # Sequence loss over all timesteps, masked by caption lengths
            # (targets exclude the <start> token).
            txt_rc_loss, _, _, _ = masked_cross_entropy(
                decoder_outputs[:target_length - 1].transpose(0, 1).contiguous(),
                captions[1:, :, 0].transpose(0, 1).contiguous(),
                lengths - 1
            )
            # </editor-fold>

            # Cross-modal loss over the first `mask` shared dimensions.
            Tz = encoder_output[:, 0, :]
            txt = Tz.narrow(1, 0, mask)
            im = Iz.narrow(1, 0, mask)

            if args.cm_criterion == 'MSE':
                cm_loss = mse_loss(txt, im)
            else:
                cm_loss = cm_criterion(txt, im,
                                       Variable(torch.ones(im.size(0)).cuda()))

            # K negative samples: push apart randomly mismatched pairs.
            # Negative pressure decays linearly over the first 20 epochs.
            k = args.negative_samples
            neg_rate = (20 - epoch) / 20
            for _ in range(k):
                if cuda:
                    perm = torch.randperm(args.batch_size).cuda()
                else:
                    perm = torch.randperm(args.batch_size)

                if args.cm_criterion == 'MSE':
                    # Weight each negative by (cosine similarity - 1)/(2k),
                    # so near-duplicates are penalized less.
                    sim = (F.cosine_similarity(txt, txt[perm]) - 1) / (2 * k)
                    cm_loss += mse_loss(txt, im[perm], sim)
                else:
                    cm_loss += neg_rate * cm_criterion(
                        txt, im[perm],
                        Variable(-1 * torch.ones(txt.size(0)).cuda())) / k

            # Blend reconstruction and cross-modal terms for backprop.
            img_loss = img_rc_loss * (1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
            txt_loss = txt_rc_loss * (1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight

            txt_losses.update(txt_rc_loss.data[0], args.batch_size)
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            cm_losses.update(cm_loss.data[0], args.batch_size)

            # <editor-fold desc="Back Propagation">
            # Alternate: image pipeline on one batch, text on the next.
            if train_images:
                img_loss.backward()
                img_enc_optim.step()
                img_dec_optim.step()
            else:
                txt_loss.backward()
                txt_enc_optim.step()
                txt_dec_optim.step()
            train_images = not train_images
            # </editor-fold>

            # <editor-fold desc="Logging">
            if i % args.image_save_interval == 0:
                # BUGFIX: use integer division -- `/` produces float-named
                # directories ("0.0") under Python 3.
                subdir_path = os.path.join(result_path, str(i // args.image_save_interval))
                if not os.path.exists(subdir_path):
                    os.makedirs(subdir_path)

                for im_idx in range(3):
                    # Undo the [-1, 1] normalization before saving.
                    im_or = (images[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255
                    im = (IzI[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2 + .5) * 255

                    filename_prefix = os.path.join(subdir_path, str(im_idx))
                    scipy.misc.imsave(filename_prefix + '_original.A.jpg', im_or)
                    scipy.misc.imsave(filename_prefix + '.A.jpg', im)

                    txt_or = " ".join([vocab.idx2word[c] for c in
                                       list(captions[:, im_idx].view(-1).cpu().data)])
                    txt = " ".join([vocab.idx2word[c] for c in
                                    list(decoder_outputs[:, im_idx].view(-1).cpu().data)])
                    print("Original: ", txt_or)
                    print(txt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                img_l=img_losses.avg,
                txt_l=txt_losses.avg,
                cm_l=cm_losses.avg,
            )
            bar.next()
            # </editor-fold>
        bar.finish()

        # <editor-fold desc="Saving the models">
        print('\n')
        print('Saving the models in {}...'.format(model_path))
        torch.save(decoder_Img.state_dict(),
                   os.path.join(model_path, 'decoder-img-%d-' % (epoch + 1)) + current_date + ".pkl")
        torch.save(encoder_Img.state_dict(),
                   os.path.join(model_path, 'encoder-img-%d-' % (epoch + 1)) + current_date + ".pkl")
        torch.save(decoder_Txt.state_dict(),
                   os.path.join(model_path, 'decoder-txt-%d-' % (epoch + 1)) + current_date + ".pkl")
        torch.save(encoder_Txt.state_dict(),
                   os.path.join(model_path, 'encoder-txt-%d-' % (epoch + 1)) + current_date + ".pkl")
        # </editor-fold>

        if args.validate == "true":
            validate(encoder_Img, encoder_Txt, val_loader, mask, 10)
def main():
    """Validation-only entry point.

    Loads saved image/text encoder checkpoints for each epoch, builds paired
    embeddings over the validation set, and reports MSE between the two
    modalities plus top-k nearest-neighbour retrieval accuracy.

    NOTE(review): this function appears truncated/garbled -- `txt_emb`,
    `img_emb`, `i`, `batch_time`, `bar`, `limit` and `end` are used below
    without any visible defining batch loop. A `for i, ... in
    enumerate(val_loader)` body computing the per-batch embeddings seems to
    be missing; verify against the original source before running.
    """
    # global args
    args = parser.parse_args()

    assert args.criterion in ("MSE","Cosine","Hinge"), 'Invalid Loss Function'

    # CLI passes cuda as the string "true"/"false"; normalize to bool.
    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    result_path = args.result_path
    model_path = args.model_path

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # Load vocabulary wrapper.
    print('\n')
    print("\033[94mLoading Vocabulary...")  # ANSI blue
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings; resolve the GloVe file from a directory path.
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1]=='/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # Pretrained, frozen GloVe lookup table (project `Embeddings` wrapper).
    glove_emb = Embeddings(emb_size,len(vocab.word2idx),vocab.word2idx["<pad>"])
    glove_emb.word_lut.weight.data.copy_(emb)
    glove_emb.word_lut.weight.requires_grad = False

    # glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)

    # Freeze weighs
    # if args.fixed_embeddings == "true":
    #     glove_emb.weight.requires_grad = False

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                            transform, args.batch_size,
                            shuffle=True, num_workers=args.num_workers)

    # Only the two encoders are needed for retrieval evaluation.
    print("Setting up the Networks...")
    encoder_Txt = TextEncoderOld(glove_emb, num_layers=1, bidirectional=False,
                                 hidden_size=args.hidden_size)
    # decoder_Txt = TextDecoderOld(glove_emb, num_layers=1, bidirectional=False, hidden_size=args.hidden_size)
    # decoder_Txt = TextDecoder(encoder_Txt, glove_emb)
    # decoder_Txt = DecoderRNN(glove_emb, hidden_size=args.hidden_size)

    encoder_Img = ImageEncoder(img_dimension=args.crop_size,feature_dimension= args.hidden_size)
    # decoder_Img = ImageDecoder(img_dimension=args.crop_size, feature_dimension= args.hidden_size)

    if cuda:
        encoder_Txt = encoder_Txt.cuda()
        encoder_Img = encoder_Img.cuda()

    for epoch in range(args.num_epochs):
        # VALIDATION TIME
        print('\033[92mEPOCH ::: VALIDATION ::: ' + str(epoch + 1))  # ANSI green

        # Load the models
        print("Loading the models...")
        # Hard-coded checkpoint runs tried during experimentation; the last
        # assignments win. `mask` is the shared-embedding width the
        # corresponding run was trained with.
        # suffix = '-{}-05-28-13-14.pkl'.format(epoch+1)
        # mask = 300
        prefix = ""
        suffix = '-{}-05-28-09-23.pkl'.format(epoch+1)
        # suffix = '-{}-05-28-11-35.pkl'.format(epoch+1)
        # suffix = '-{}-05-28-16-45.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-00-28.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-00-30.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-01-08.pkl'.format(epoch+1)
        mask = 200
        # suffix = '-{}-05-28-15-39.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-12-11.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-12-14.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-14-24.pkl'.format(epoch+1) #best
        # suffix = '-{}-05-29-15-43.pkl'.format(epoch+1)
        date = "06-30-14-22"
        date = "07-01-12-49" #bad
        date = "07-01-16-38"
        date = "07-01-18-16"
        date = "07-02-15-38"
        date = "07-08-15-12"  # effective run: only this last date is used
        prefix = "{}/".format(date)
        suffix = '-{}-{}.pkl'.format(epoch+1,date)
        mask = 100
        print(suffix)

        try:
            encoder_Img.load_state_dict(torch.load(os.path.join(args.model_path,
                                                   prefix + 'encoder-img' + suffix)))
            encoder_Txt.load_state_dict(torch.load(os.path.join(args.model_path,
                                                   prefix + 'encoder-txt' + suffix)))
        except FileNotFoundError:
            print("\n\033[91mFile not found...\nTerminating Validation Procedure!")  # ANSI red
            break

        # NOTE(review): `txt_emb`, `img_emb` and `i` are not defined in the
        # visible code -- presumably produced by a per-batch loop over
        # val_loader that was lost; confirm against the original file.
        # Stack text embeddings (row 0) and image embeddings (row 1).
        current_embeddings = np.concatenate( \
            (txt_emb.cpu().data.numpy(),\
             img_emb.unsqueeze(0).cpu().data.numpy())\
            ,0)

        # current_embeddings = img_emb.data
        if i:
            # result_embeddings = torch.cat( \
            result_embeddings = np.concatenate( \
                (result_embeddings, current_embeddings) \
                ,1)
        else:
            result_embeddings = current_embeddings

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
            batch=i,
            size=len(val_loader),
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
        )
        bar.next()
        bar.finish()

        # Mean squared distance between matched text/image embedding pairs.
        a = [((result_embeddings[0][i] - result_embeddings[1][i]) ** 2).mean()
             for i in range(limit*args.batch_size)]
        print("Validation MSE: ",np.mean(a))
        print("Validation MSE: ",np.mean(a))

        print("Computing Nearest Neighbors...")
        i = 0  # flag: when truthy, embeddings are normalized first (disabled)
        topk = []
        kss = [1,10,50]
        for k in kss:
            if i:
                print("Normalized ")
                result_embeddings[0] = result_embeddings[0]/result_embeddings[0].sum()
                result_embeddings[1] = result_embeddings[1]/result_embeddings[1].sum()

            # k = 5
            # Fit on image embeddings, query with text embeddings
            # (text -> image retrieval).
            neighbors = NearestNeighbors(k, metric = 'cosine')
            neigh = neighbors
            neigh.fit(result_embeddings[1])
            kneigh = neigh.kneighbors(result_embeddings[0], return_distance=False)

            # Fraction of distinct images that appear in anyone's top-k.
            ks = set()
            for n in kneigh:
                ks.update(set(n))
            print(len(ks)/result_embeddings.shape[1])

            # a = [((result_embeddings[0][i] - result_embeddings[1][i]) ** 2).mean() for i in range(128)]
            # rs = result_embeddings.sum(2)
            # a = (((result_embeddings[0][0]- result_embeddings[1][0])**2).mean())
            # b = (((result_embeddings[0][0]- result_embeddings[0][34])**2).mean())

            # Top-k accuracy: query i retrieves its own paired image.
            topk.append(np.mean([int(i in nn) for i,nn in enumerate(kneigh)]))

        print("Top-{k:},{k2:},{k3:} accuracy for Image Retrieval:\n\n\t\033[95m {tpk: .3f}% \t {tpk2: .3f}% \t {tpk3: .3f}% \n".format(
            k=kss[0], k2=kss[1], k3=kss[2],
            tpk= 100*topk[0], tpk2= 100*topk[1], tpk3= 100*topk[2]))
def main():
    """Train a sentence VAE and an image VAE tied by a cross-modal loss.

    Pipeline: parse CLI args -> build transforms, vocab, embedding table and
    data loaders -> build `SentenceVAE` and `ImgVAE` -> per batch, compute
    each VAE's reconstruction+KL loss plus `crossmodal_loss` on the latent
    codes, and step the optimizers; per epoch, save sample reconstructions,
    generated captions, and checkpoints.

    Relies on module-level globals: `parser`, `get_loader`,
    `load_glove_embeddings`, `SentenceVAE`, `ImgVAE`, `img_vae_loss`,
    `seq_vae_loss`, `crossmodal_loss`, `AverageMeter`, `Bar`, `to_var`,
    `validate`, and the usual torch/torchvision/scipy imports.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "test":
        print("WARNING: name is test!!!\n\n")

    # now = datetime.datetime.now()
    # current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge", "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine", "Hinge"), 'Invalid Loss Function'
    assert args.common_emb_ratio <= 1.0 and args.common_emb_ratio >= 0

    # Number of leading latent dimensions shared across the two modalities.
    mask = int(args.common_emb_ratio * args.hidden_size)

    # CLI passes cuda as the string "true"/"false"; normalize to bool.
    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    # With no model to load, name the run directory after args.comment and
    # train from scratch; otherwise resume checkpoints from load_model's dir.
    if args.load_model == "NONE":
        keep_loading = False
        # model_path = args.model_path + current_date + "/"
        model_path = args.model_path + args.comment + "/"
    else:
        keep_loading = True
        model_path = args.model_path + args.load_model + "/"

    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    #</editor-fold>

    # <editor-fold desc="Image Preprocessing">
    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))
    ])

    # Inverse of the normalization above, for visualizing reconstructions.
    # (Currently unused -- the logging code below un-normalizes manually.)
    inv_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.255],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.255])
    #</editor-fold>

    # <editor-fold desc="Creating Embeddings">
    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings; resolve the GloVe file from a directory path.
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # glove_emb = Embeddings(emb_size,len(vocab.word2idx),vocab.word2idx["<pad>"])
    # glove_emb.word_lut.weight.data.copy_(emb)
    # glove_emb.word_lut.weight.requires_grad = False

    # NOTE(review): the pretrained weights `emb` are loaded but never copied
    # into `glove_emb` (the copy is commented out); the embedding trains
    # from random init. Confirm this is intended.
    glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)

    # Freeze weighs
    # if args.fixed_embeddings == "true":
    #     glove_emb.weight.requires_grad = False

    # </editor-fold>

    # <editor-fold desc="Data-Loaders">
    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                            transform, args.batch_size,
                            shuffle=True, num_workers=args.num_workers)
    # </editor-fold>

    # <editor-fold desc="Network Initialization">
    print("Setting up the Networks...")
    vae_Txt = SentenceVAE(glove_emb, len(vocab), hidden_size=args.hidden_size,
                          latent_size=args.latent_size, batch_size=args.batch_size)
    vae_Img = ImgVAE(img_dimension=args.crop_size, hidden_size=args.hidden_size,
                     latent_size=args.latent_size)

    if cuda:
        vae_Txt = vae_Txt.cuda()
        vae_Img = vae_Img.cuda()
    # </editor-fold>

    # <editor-fold desc="Losses">
    # Losses and Optimizers
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    # txt_criterion = nn.MSELoss(size_average=True)
    if args.text_criterion == 'MSE':
        txt_criterion = nn.MSELoss()
    elif args.text_criterion == "Cosine":
        txt_criterion = nn.CosineEmbeddingLoss(size_average=False)
    elif args.text_criterion == "NLLLoss":
        txt_criterion = nn.NLLLoss()
    else:
        txt_criterion = nn.HingeEmbeddingLoss(size_average=False)

    if args.cm_criterion == 'MSE':
        cm_criterion = nn.MSELoss()
    elif args.cm_criterion == "Cosine":
        cm_criterion = nn.CosineEmbeddingLoss()
    else:
        cm_criterion = nn.HingeEmbeddingLoss()

    if cuda:
        img_criterion = img_criterion.cuda()
        txt_criterion = txt_criterion.cuda()
        cm_criterion = cm_criterion.cuda()
    # txt_criterion = nn.CrossEntropyLoss()
    # </editor-fold>

    # <editor-fold desc="Optimizers">
    print("Setting up the Optimizers...")
    img_optim = optim.Adam(vae_Img.parameters(), lr=args.learning_rate,
                           betas=(0.5, 0.999), weight_decay=0.00001)
    txt_optim = optim.Adam(vae_Txt.parameters(), lr=args.learning_rate,
                           betas=(0.5, 0.999), weight_decay=0.00001)
    # </editor-fold desc="Optimizers">

    # NOTE(review): the per-batch toggle below is commented out, so only the
    # image pipeline is ever stepped while train_images stays True -- the
    # text optimizer never runs. Confirm this is intended.
    train_images = True  # Reverse 2
    step = 0  # global batch counter, drives the KL annealing schedule

    for epoch in range(args.num_epochs):
        # <editor-fold desc = "Epoch Initialization"?
        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        if keep_loading:
            # Try to resume this epoch's checkpoints; fall back to training
            # once none exists.
            suffix = "-" + str(epoch) + "-" + args.load_model + ".pkl"
            try:
                vae_Img.load_state_dict(
                    torch.load(
                        os.path.join(args.model_path, 'vae-img' + suffix)))
                vae_Txt.load_state_dict(
                    torch.load(
                        os.path.join(args.model_path, 'vae-txt' + suffix)))
            except FileNotFoundError:
                print("Didn't find any models switching to training")
                keep_loading = False

        if not keep_loading:
            # Set training mode
            vae_Txt.train()
            vae_Img.train()
        # </editor-fold desc = "Epoch Initialization"?

        # train_images = not train_images
        for i, (images, captions, lengths) in enumerate(data_loader):
            # Skip the final (possibly smaller) batch: the code assumes a
            # fixed args.batch_size throughout.
            if i == len(data_loader) - 1:
                break

            # <editor-fold desc = "Training Parameters Initiliazation"?
            # Set mini-batch dataset
            images = to_var(images)
            captions = to_var(captions)

            # captions = captions.transpose(0,1).unsqueeze(2)
            lengths = to_var(
                torch.LongTensor(lengths))  # print(captions.size())

            # Forward, Backward and Optimize
            img_optim.zero_grad()
            txt_optim.zero_grad()
            # </editor-fold desc = "Training Parameters Initiliazation"?

            # <editor-fold desc = "Forward passes"?
            img_out, img_mu, img_logv, img_z = vae_Img(images)
            txt_out, txt_mu, txt_logv, txt_z = vae_Txt(captions, lengths)

            # Per-pixel-normalized image ELBO.
            img_rc_loss = img_vae_loss(
                img_out, images, img_mu,
                img_logv) / (args.batch_size * args.crop_size**2)

            # Sentence ELBO with logistic KL-weight annealing
            # (x0=2500 steps, k=0.0025), normalized per token.
            NLL_loss, KL_loss, KL_weight = seq_vae_loss(
                txt_out, captions, lengths, txt_mu, txt_logv, "logistic",
                step, 0.0025, 2500)

            txt_rc_loss = (NLL_loss + KL_weight * KL_loss) / torch.sum(lengths).float()

            # Align the first `mask` latent dims of the two modalities.
            cm_loss = crossmodal_loss(txt_z, img_z, mask, args.cm_criterion,
                                      cm_criterion, args.negative_samples,
                                      epoch)

            # cm_loss += crossmodal_loss(txt_logv, img_logv, mask,
            #                            args.cm_criterion, cm_criterion,
            #                            args.negative_samples, epoch)

            # Computes the loss to be back-propagated
            img_loss = img_rc_loss * (
                1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
            txt_loss = txt_rc_loss * (
                1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
            # txt_loss = txt_rc_loss + cm_loss * args.cm_loss_weight
            # img_loss = img_rc_loss + cm_loss * args.cm_loss_weight

            txt_losses.update(txt_rc_loss.data[0], args.batch_size)
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            cm_losses.update(cm_loss.data[0], args.batch_size)
            # </editor-fold desc = "Loss accumulation"?

            # <editor-fold desc = "Back Propagation">
            # Half of the times we update one pipeline the others the other one
            if train_images:
                # Image Network Training and Backpropagation
                img_loss.backward()
                img_optim.step()
            else:
                # Text Nextwork Training & Back Propagation
                txt_loss.backward()
                txt_optim.step()

            step += 1
            # train_images = not train_images
            # </editor-fold desc = "Back Propagation">

            # <editor-fold desc = "Logging">
            if i % args.image_save_interval == 0:
                # NOTE(review): `/` yields float dir names ("0.0") on
                # Python 3; `//` was likely intended.
                subdir_path = os.path.join(
                    result_path, str(i / args.image_save_interval))

                if os.path.exists(subdir_path):
                    pass
                else:
                    os.makedirs(subdir_path)

                for im_idx in range(3):
                    # im_or = (inv_normalize([im_idx]).cpu().data.numpy().transpose(1,2,0))*255
                    # im = (inv_normalize([im_idx]).cpu().data.numpy().transpose(1,2,0))*255
                    # Undo the [-1, 1] scaling before writing JPEGs.
                    im_or = (images[im_idx].cpu().data.numpy().transpose(
                        1, 2, 0) / 2 + .5) * 255
                    im = (img_out[im_idx].cpu().data.numpy().transpose(
                        1, 2, 0) / 2 + .5) * 255
                    # im = img_out[im_idx].cpu().data.numpy().transpose(1,2,0)*255

                    filename_prefix = os.path.join(subdir_path, str(im_idx))
                    scipy.misc.imsave(filename_prefix + '_original.A.jpg', im_or)
                    scipy.misc.imsave(filename_prefix + '.A.jpg', im)

                    txt_or = " ".join([
                        vocab.idx2word[c]
                        for c in captions[im_idx].cpu().data.numpy()
                    ])
                    # Greedy decode: most likely token per timestep.
                    _, generated = torch.topk(txt_out[im_idx], 1)
                    txt = " ".join([
                        vocab.idx2word[c]
                        for c in generated[:, 0].cpu().data.numpy()
                    ])

                    with open(filename_prefix + "_captions.txt", "w") as text_file:
                        text_file.write("Original: %s\n" % txt_or)
                        text_file.write("Generated: %s" % txt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                img_l=img_losses.avg,
                txt_l=txt_losses.avg,
                cm_l=cm_losses.avg,
            )
            bar.next()
            # </editor-fold desc = "Logging">
        bar.finish()

        # <editor-fold desc = "Saving the models"?
        # Save the models
        print('\n')
        print('Saving the models in {}...'.format(model_path))
        torch.save(
            vae_Img.state_dict(),
            os.path.join(model_path, 'vae-img-%d-' % (epoch + 1)) + ".pkl")
        torch.save(
            vae_Txt.state_dict(),
            os.path.join(model_path, 'vae-txt-%d-' % (epoch + 1)) + ".pkl")
        # </editor-fold desc = "Saving the models"?

        if args.validate == "true":
            validate(vae_Img, vae_Txt, val_loader, mask, 10)