Exemplo n.º 1
0
def main():
    """Train a coupled image/caption VAE end to end.

    Reads configuration from the module-level ``parser``, loads a pickled
    vocabulary and GloVe word embeddings, builds train/validation data
    loaders, then runs the epoch loop: per-epoch checkpoint loading when
    ``--load_model`` is set, otherwise VAE training with image + sequence
    ELBO losses, periodic image/caption reconstruction dumps, CSV loss
    logging, and a model checkpoint at the end of every epoch.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "test":
        print("WARNING: name is test!!!\n\n")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge",
                                   "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine",
                                 "Hinge"), 'Invalid Loss Function'

    assert 0 <= args.common_emb_ratio <= 1.0

    cuda = args.cuda == 'true'

    # Checkpoint directory: a fresh run is keyed by the comment string,
    # a resumed run by the checkpoint tag it loads from.
    if args.load_model == "NONE":
        keep_loading = False
        model_path = args.model_path + args.comment + "/"
    else:
        keep_loading = True
        model_path = args.model_path + args.load_model + "/"

    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"

    os.makedirs(result_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)
    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">

    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load embeddings; a bare directory path is completed with the
    # canonical glove.6B file name for the requested dimensionality.
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # FIX: the loaded GloVe matrix was never copied into the layer, so
    # the model trained on randomly initialized embeddings.
    glove_emb.weight.data.copy_(emb)

    # Freeze the embedding weights when requested.
    if args.fixed_embeddings == "true":
        glove_emb.weight.requires_grad = False

    # </editor-fold>

    # <editor-fold desc="Data-Loaders">

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # NOTE(review): val_loader is built for parity with the sibling
    # pipelines but is never consumed in this loop.
    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir,
                            args.valid_caption_path,
                            vocab,
                            transform,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    # </editor-fold>

    # <editor-fold desc="Network Initialization">

    print("Setting up the Networks...")
    coupled_vae = CoupledVAE(glove_emb,
                             len(vocab),
                             hidden_size=args.hidden_size,
                             latent_size=args.latent_size,
                             batch_size=args.batch_size)

    if cuda:
        coupled_vae = coupled_vae.cuda()

    # </editor-fold>

    # <editor-fold desc="Optimizers">
    print("Setting up the Optimizers...")

    vae_optim = optim.Adam(coupled_vae.parameters(),
                           lr=args.learning_rate,
                           betas=(0.5, 0.999),
                           weight_decay=0.00001)

    # </editor-fold desc="Optimizers">

    train_swapped = False  # Reverse 2

    step = 0

    # Start the CSV log fresh for this run.
    with open(os.path.join(result_path, "losses.csv"), "w") as text_file:
        text_file.write("Epoch, Img, Txt, CM\n")

    for epoch in range(args.num_epochs):

        # <editor-fold desc="Epoch Initialization">

        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        # While resuming, try to load this epoch's checkpoint; fall back
        # to training as soon as one is missing.
        if keep_loading:
            suffix = "-" + str(epoch) + "-" + args.load_model + ".pkl"
            try:
                coupled_vae.load_state_dict(
                    torch.load(
                        os.path.join(args.model_path, 'coupled_vae' + suffix)))
            except FileNotFoundError:
                print("Didn't find any models switching to training")
                keep_loading = False

        if not keep_loading:

            # Set training mode
            coupled_vae.train()

            # </editor-fold desc="Epoch Initialization">

            # Alternate the swapped/straight decoding direction each epoch.
            train_swapped = not train_swapped
            for i, (images, captions, lengths) in enumerate(data_loader):

                # Skip the (possibly short) final batch.
                if i == len(data_loader) - 1:
                    break

                images = to_var(images)
                captions = to_var(captions)
                lengths = to_var(torch.LongTensor(lengths))

                # Forward, Backward and Optimize
                vae_optim.zero_grad()

                img_out, img_mu, img_logv, img_z, txt_out, txt_mu, txt_logv, txt_z = \
                                                                 coupled_vae(images, captions, lengths, train_swapped)

                # Per-pixel image ELBO.
                img_rc_loss = img_vae_loss(
                    img_out, images, img_mu,
                    img_logv) / (args.batch_size * args.crop_size**2)

                # Sequence ELBO with logistic KL annealing driven by `step`.
                NLL_loss, KL_loss, KL_weight = seq_vae_loss(
                    txt_out, captions, lengths, txt_mu, txt_logv, "logistic",
                    step, 0.0025, 2500)
                txt_rc_loss = (NLL_loss + KL_weight *
                               KL_loss) / torch.sum(lengths).float()

                # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar
                # API; on modern PyTorch this would be `.item()`.
                txt_losses.update(txt_rc_loss.data[0], args.batch_size)
                img_losses.update(img_rc_loss.data[0], args.batch_size)

                loss = img_rc_loss + txt_rc_loss

                loss.backward()
                vae_optim.step()
                step += 1

                if i % args.image_save_interval == 0:
                    # FIX: integer division keeps subdirectory names like
                    # "0", "1", ... instead of "0.0" under Python 3.
                    subdir_path = os.path.join(
                        result_path, str(i // args.image_save_interval))
                    os.makedirs(subdir_path, exist_ok=True)

                    for im_idx in range(3):
                        # Undo the [-1, 1]-style normalization before saving.
                        im_or = (images[im_idx].cpu().data.numpy().transpose(
                            1, 2, 0) / 2 + .5) * 255
                        im = (img_out[im_idx].cpu().data.numpy().transpose(
                            1, 2, 0) / 2 + .5) * 255

                        filename_prefix = os.path.join(subdir_path,
                                                       str(im_idx))
                        scipy.misc.imsave(filename_prefix + '_original.A.jpg',
                                          im_or)
                        scipy.misc.imsave(filename_prefix + '.A.jpg', im)

                        # Decode the ground-truth and greedily decoded
                        # (top-1 per step) captions back into words.
                        txt_or = " ".join([
                            vocab.idx2word[c]
                            for c in captions[im_idx].cpu().data.numpy()
                        ])
                        _, generated = torch.topk(txt_out[im_idx], 1)
                        txt = " ".join([
                            vocab.idx2word[c]
                            for c in generated[:, 0].cpu().data.numpy()
                        ])

                        with open(filename_prefix + "_captions.txt",
                                  "w") as text_file:
                            text_file.write("Epoch %d\n" % epoch)
                            text_file.write("Original: %s\n" % txt_or)
                            text_file.write("Generated: %s" % txt)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                    batch=i,
                    size=len(data_loader),
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    img_l=img_losses.avg,
                    txt_l=txt_losses.avg,
                    cm_l=cm_losses.avg,
                )
                bar.next()

            # </editor-fold desc="Logging">

            bar.finish()

            with open(os.path.join(result_path, "losses.csv"),
                      "a") as text_file:
                text_file.write("{}, {}, {}, {}\n".format(
                    epoch, img_losses.avg, txt_losses.avg, cm_losses.avg))

            # <editor-fold desc="Saving the models">
            # Save the models
            print('\n')
            print('Saving the models in {}...'.format(model_path))
            # FIX: "'coupled_vae' % (epoch + 1)" had no conversion
            # specifier and raised TypeError; embed the epoch explicitly.
            torch.save(
                coupled_vae.state_dict(),
                os.path.join(model_path, 'coupled_vae-%d' % (epoch + 1)) +
                ".pkl")
Exemplo n.º 2
0
def main():
    """Train an image auto-encoder (encoder + decoder) on a captioned dataset.

    Reads configuration from the module-level ``parser``, loads a pickled
    vocabulary and GloVe embeddings (frozen, via an ``Embeddings`` lookup
    table), builds train/validation loaders, then trains only the image
    encoder/decoder pair with an MSE reconstruction loss, periodically
    saving original/reconstructed image pairs and per-epoch checkpoints
    stamped with the launch date.
    """
    print("Initializing...")
    # global args
    args = parser.parse_args()

    # Launch timestamp used to namespace checkpoints for this run.
    now = datetime.datetime.now()
    current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE", "Cosine",
                                   "Hinge"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine",
                                 "Hinge"), 'Invalid Loss Function'

    # Size of the embedding slice shared across modalities; must fit
    # within the hidden representation.
    mask = args.common_emb_size
    assert mask <= args.hidden_size

    # CLI passes cuda as the string "true"/"false"; convert to bool.
    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    result_path = args.result_path
    model_path = args.model_path + current_date + "/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        print("Creating model path on", model_path)
        os.makedirs(model_path)

    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings; a bare directory path is completed with the
    # canonical glove.6B file name for the requested dimensionality.
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # Copy the pretrained GloVe matrix into the lookup table and freeze it.
    glove_emb = Embeddings(emb_size, len(vocab.word2idx),
                           vocab.word2idx["<pad>"])
    glove_emb.word_lut.weight.data.copy_(emb)
    glove_emb.word_lut.weight.requires_grad = False

    # glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)

    # Freeze weighs
    # if args.fixed_embeddings == "true":
    # glove_emb.weight.requires_grad = False

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # NOTE(review): val_loader is built but never consumed in this loop.
    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir,
                            args.valid_caption_path,
                            vocab,
                            transform,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    print("Setting up the Networks...")

    encoder_Img = ImageEncoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)
    decoder_Img = ImageDecoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)

    if cuda:
        encoder_Img = encoder_Img.cuda()
        decoder_Img = decoder_Img.cuda()

    # Losses and Optimizers
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    # txt_criterion = nn.MSELoss(size_average=True)

    if cuda:
        img_criterion = img_criterion.cuda()
    # txt_criterion = nn.CrossEntropyLoss()

    #     gen_params = chain(generator_A.parameters(), generator_B.parameters())
    print("Setting up the Optimizers...")
    # img_params = chain(decoder_Img.parameters(), encoder_Img.parameters())
    # One optimizer over both encoder and decoder parameters.
    img_params = list(decoder_Img.parameters()) + list(
        encoder_Img.parameters())

    # ATTENTION: Check betas and weight decay
    # ATTENTION: Check why valid_params fails on image networks with out of memory error

    img_optim = optim.Adam(
        img_params, lr=0.001)  #,betas=(0.5, 0.999), weight_decay=0.00001)
    # img_enc_optim = optim.Adam(encoder_Img.parameters(), lr=args.learning_rate)#betas=(0.5, 0.999), weight_decay=0.00001)
    # img_dec_optim = optim.Adam(decoder_Img.parameters(), lr=args.learning_rate)#betas=(0.5,0.999), weight_decay=0.00001)

    train_images = False  # Reverse 2
    for epoch in range(args.num_epochs):

        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        img_losses = AverageMeter()
        txt_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        # Set training mode
        encoder_Img.train()
        decoder_Img.train()

        train_images = True
        for i, (images, captions, lengths) in enumerate(data_loader):
            # ATTENTION REMOVE
            # NOTE(review): hard-coded debug cutoff left in by the author
            # (flagged "ATTENTION REMOVE" above) — limits each epoch to
            # 6450 batches.
            if i == 6450:
                break

            # Set mini-batch dataset
            images = to_var(images)
            captions = to_var(captions)

            # target = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            # captions, lengths = pad_sequences(captions, lengths)
            # images = torch.FloatTensor(images)

            # Reshape captions to (seq_len, batch, 1) — the layout the
            # (unused here) text pipeline expects.
            captions = captions.transpose(0, 1).unsqueeze(2)
            lengths = torch.LongTensor(lengths)  # print(captions.size())

            # Forward, Backward and Optimize
            # img_optim.zero_grad()
            # img_dec_optim.zero_grad()
            # img_enc_optim.zero_grad()
            encoder_Img.zero_grad()
            decoder_Img.zero_grad()

            # txt_params.zero_grad()
            # txt_dec_optim.zero_grad()
            # txt_enc_optim.zero_grad()

            # Image Auto_Encoder Forward

            img_encoder_outputs, Iz = encoder_Img(images)

            IzI = decoder_Img(img_encoder_outputs)

            img_rc_loss = img_criterion(IzI, images)

            # Text Auto Encoder Forward

            # target = target[:-1] # exclude last target from inputs

            img_loss = img_rc_loss

            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar API.
            # Text and cross-modal losses are logged as zero placeholders.
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            txt_losses.update(0, args.batch_size)
            cm_losses.update(0, args.batch_size)

            # Image Network Training and Backpropagation

            img_loss.backward()
            img_optim.step()

            if i % args.image_save_interval == 0:
                # NOTE(review): under Python 3, i / interval yields a float,
                # so subdirectories are named "0.0", "1.0", ... — confirm
                # whether "//" was intended.
                subdir_path = os.path.join(result_path,
                                           str(i / args.image_save_interval))

                if os.path.exists(subdir_path):
                    pass
                else:
                    os.makedirs(subdir_path)

                # Dump the first three original/reconstructed pairs,
                # undoing the [-1, 1]-style normalization before saving.
                for im_idx in range(3):
                    im_or = (images[im_idx].cpu().data.numpy().transpose(
                        1, 2, 0) / 2 + .5) * 255
                    im = (IzI[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2
                          + .5) * 255

                    filename_prefix = os.path.join(subdir_path, str(im_idx))
                    scipy.misc.imsave(filename_prefix + '_original.A.jpg',
                                      im_or)
                    scipy.misc.imsave(filename_prefix + '.A.jpg', im)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                img_l=img_losses.avg,
                txt_l=txt_losses.avg,
                cm_l=cm_losses.avg,
            )
            bar.next()
        bar.finish()

        # Save the models
        print('\n')
        print('Saving the models in {}...'.format(model_path))
        torch.save(
            decoder_Img.state_dict(),
            os.path.join(model_path, 'decoder-img-%d-' % (epoch + 1)) +
            current_date + ".pkl")
        torch.save(
            encoder_Img.state_dict(),
            os.path.join(model_path, 'encoder-img-%d-' % (epoch + 1)) +
            current_date + ".pkl")
Exemplo n.º 3
0
def main():
    """Dispatch training to the trainer selected by ``--method``.

    Reads configuration from the module-level ``parser``, picks one of the
    registered trainer classes (coupled VAE-GAN, coupled VAE, WGAN,
    sequence WGAN, skip-thoughts), builds word embeddings (GloVe or
    randomly initialized) and COCO/CUB data loaders, then delegates each
    epoch's forward/optimize steps — and optionally a validation pass —
    to the trainer object, saving final models at the end.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "NONE":
        args.comment = args.method

    validate = args.validate == "true"

    # Select the trainer class for the requested method.
    if args.method == "coupled_vae_gan":
        trainer = coupled_vae_gan_trainer.coupled_vae_gan_trainer
    elif args.method == "coupled_vae":
        trainer = coupled_vae_trainer.coupled_vae_trainer
    elif args.method == "wgan":
        trainer = wgan_trainer.wgan_trainer
    elif args.method == "seq_wgan":
        trainer = seq_wgan_trainer.wgan_trainer
    elif args.method == "skip_thoughts":
        trainer = skipthoughts_vae_gan_trainer.coupled_vae_gan_trainer
    else:
        assert False, "Invalid method"

    assert args.text_criterion in ("MSE", "Cosine", "Hinge",
                                   "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine",
                                 "Hinge"), 'Invalid Loss Function'

    assert args.common_emb_ratio <= 1.0 and args.common_emb_ratio >= 0

    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((.5, .5, .5), (.5, .5, .5))
        # transforms.Normalize((0.485, 0.456, 0.406),
        #                      (0.229, 0.224, 0.225))
    ])

    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">
    # Non-COCO runs use the CUB vocabulary.
    if args.dataset != "coco":
        args.vocab_path = "./data/cub_vocab.pkl"

    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load embeddings; a bare directory path is completed with the
    # canonical glove.6B file name for the requested dimensionality.
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")

    # Either pretrained GloVe vectors or a randomly initialized table.
    use_glove = args.use_glove == "true"
    if use_glove:
        emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)
        word_emb = nn.Embedding(emb.size(0), emb.size(1))
        word_emb.weight = nn.Parameter(emb)
    else:
        word_emb = nn.Embedding(len(vocab), emb_size)

    # Freeze the embedding weights when requested.
    if args.fixed_embeddings == "true":
        # FIX: freezing requires requires_grad = False; the original set
        # it to True, which left the "fixed" embeddings trainable.
        word_emb.weight.requires_grad = False

    # </editor-fold>

    # <editor-fold desc="Data-Loaders">

    # Build data loader
    print("Building Data Loader For Test Set...")
    if args.dataset == 'coco':
        data_loader = get_loader(args.image_dir,
                                 args.caption_path,
                                 vocab,
                                 transform,
                                 args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers)

        print("Building Data Loader For Validation Set...")
        val_loader = get_loader(args.valid_dir,
                                args.valid_caption_path,
                                vocab,
                                transform,
                                args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers)

    else:
        # CUB birds dataset packaged as HDF5: split 0 = train, 1 = val.
        data_path = "data/cub.h5"
        dataset = Text2ImageDataset(data_path,
                                    split=0,
                                    vocab=vocab,
                                    transform=transform)
        data_loader = DataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

        dataset_val = Text2ImageDataset(data_path,
                                        split=1,
                                        vocab=vocab,
                                        transform=transform)
        val_loader = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers,
                                collate_fn=collate_fn)

    # </editor-fold>

    # <editor-fold desc="Network Initialization">

    print("Setting up the trainer...")
    model_trainer = trainer(args, word_emb, vocab)

    # </editor-fold desc="Network Initialization">

    for epoch in range(args.num_epochs):

        # <editor-fold desc="Epoch Initialization">

        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        end = time.time()

        bar = Bar(args.method if args.comment == "NONE" else args.method +
                  "/" + args.comment,
                  max=len(data_loader))

        model_trainer.set_train_models()
        model_trainer.create_losses_meter(model_trainer.losses)

        for i, (images, captions, lengths) in enumerate(data_loader):
            # A successful checkpoint load means this epoch is already
            # done — skip straight to the next one.
            if model_trainer.load_models(epoch):
                break

            # Skip the (possibly short) final batch.
            if i == len(data_loader) - 1:
                break

            images = to_var(images)
            captions = to_var(captions)
            lengths = to_var(
                torch.LongTensor(lengths))  # print(captions.size())

            # The last flag asks the trainer to dump sample images every
            # `image_save_interval` iterations (truthy when i % k == 0).
            model_trainer.forward(epoch, images, captions, lengths,
                                  not i % args.image_save_interval)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if not model_trainer.iteration % args.log_step:
                # plot progress
                bar.suffix = bcolors.HEADER
                bar.suffix += '({batch}/{size}) Iter: {bt:} | Time: {total:}-{eta:}\n'.format(
                    batch=i,
                    size=len(data_loader),
                    bt=model_trainer.iteration,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                )
                bar.suffix += bcolors.ENDC

                # Append every tracked loss, five per line.
                cnt = 0
                for l_name, l_value in sorted(model_trainer.losses.items(),
                                              key=lambda x: x[0]):
                    cnt += 1
                    bar.suffix += ' | {name}: {val:.3f}'.format(
                        name=l_name,
                        val=l_value.avg,
                    )
                    if not cnt % 5:
                        bar.suffix += "\n"

                bar.next()

        # </editor-fold desc="Logging">

        bar.finish()

        if validate:
            print('EPOCH ::: VALIDATION ::: ' + str(epoch + 1))
            batch_time = AverageMeter()
            end = time.time()
            barName = args.method if args.comment == "NONE" else args.method + "/" + args.comment
            barName = "VAL:" + barName
            bar = Bar(barName, max=len(val_loader))

            model_trainer.set_eval_models()
            model_trainer.create_metrics_meter(model_trainer.metrics)

            for i, (images, captions, lengths) in enumerate(val_loader):

                # Skip the (possibly short) final batch.
                if i == len(val_loader) - 1:
                    break

                images = to_var(images)
                # Drop the leading token (presumably <start> — confirm
                # against the loader) before evaluation.
                captions = to_var(captions[:, 1:])

                model_trainer.evaluate(epoch, images, captions, lengths,
                                       i == 0)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = bcolors.HEADER
                bar.suffix += '({batch}/{size}) Iter: {bt:} | Time: {total:}-{eta:}\n'.format(
                    batch=i,
                    size=len(val_loader),
                    bt=model_trainer.iteration,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                )
                bar.suffix += bcolors.ENDC

                # Append every tracked metric, five per line.
                cnt = 0
                for l_name, l_value in sorted(model_trainer.metrics.items(),
                                              key=lambda x: x[0]):
                    cnt += 1
                    bar.suffix += ' | {name}: {val:.3f}'.format(
                        name=l_name,
                        val=l_value.avg,
                    )
                    if not cnt % 5:
                        bar.suffix += "\n"

                bar.next()

            bar.finish()

    # Persist the final models (epoch tag -1 marks the final save).
    model_trainer.save_models(-1)
Exemplo n.º 4
0
def main():
    """Train a coupled image/caption model via a module-level trainer.

    Reads configuration from the module-level ``parser``, loads a pickled
    vocabulary and GloVe embeddings, builds train/validation loaders, then
    delegates per-batch training to a ``trainer`` instance, tracking and
    saving per-epoch reconstruction losses and model checkpoints.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "test":
        print("WARNING: name is test!!!\n\n")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge",
                                   "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine",
                                 "Hinge"), 'Invalid Loss Function'

    assert 0 <= args.common_emb_ratio <= 1.0

    # </editor-fold>

    # <editor-fold desc="Image Preprocessing">

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # </editor-fold>

    # <editor-fold desc="Creating Embeddings">

    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load embeddings; a bare directory path is completed with the
    # canonical glove.6B file name for the requested dimensionality.
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # FIX: the loaded GloVe matrix was never copied into the layer, so
    # the model trained on randomly initialized embeddings.
    glove_emb.weight.data.copy_(emb)

    # Freeze the embedding weights when requested.
    if args.fixed_embeddings == "true":
        glove_emb.weight.requires_grad = False

    # </editor-fold>

    # <editor-fold desc="Data-Loaders">

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir,
                            args.valid_caption_path,
                            vocab,
                            transform,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    # </editor-fold>

    # <editor-fold desc="Network Initialization">

    print("Setting up the trainer...")
    # NOTE(review): `trainer` is not defined in this function — it is
    # expected to be bound at module level; confirm against the imports.
    model_trainer = trainer(args, glove_emb, vocab)

    # </editor-fold desc="Network Initialization">

    for epoch in range(args.num_epochs):

        # <editor-fold desc="Epoch Initialization">

        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        # FIX: txt_losses and img_losses were read below but never
        # initialized, raising NameError on the first batch.
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        for i, (images, captions, lengths) in enumerate(data_loader):

            # Skip the (possibly short) final batch.
            if i == len(data_loader) - 1:
                break

            images = to_var(images)
            captions = to_var(captions)
            lengths = to_var(
                torch.LongTensor(lengths))  # print(captions.size())

            # The last flag asks the trainer to dump sample images every
            # `image_save_interval` iterations (truthy when i % k == 0).
            img_rc_loss, txt_rc_loss = model_trainer.train(
                images, captions, lengths, not i % args.image_save_interval)

            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar API.
            txt_losses.update(txt_rc_loss.data[0], args.batch_size)
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            # cm_losses.update(cm_loss.data[0], args.batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            # FIX: the original assigned a throwaway local `bar_suffix`,
            # so the progress bar never showed any of this text.
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
            )

            bar.next()

        # </editor-fold desc="Logging">

        bar.finish()
        model_trainer.save_losses(epoch, img_losses.avg, txt_losses.avg)
        model_trainer.save_models(epoch)
Exemplo n.º 5
0
def main():
    """Jointly train an image autoencoder and a seq2seq text autoencoder
    with a cross-modal alignment loss.

    Pipeline: parse CLI args -> build GloVe-initialised (frozen) word
    embeddings -> build train/validation caption loaders -> build the four
    networks (text encoder/decoder, image encoder/decoder) -> train,
    alternating each epoch between back-propagating through the image
    pipeline and the text pipeline.  Each pipeline's loss is a weighted
    sum of its reconstruction loss and a shared cross-modal loss computed
    on the first ``mask`` latent dimensions, with ``k`` permutation-based
    negative samples.  Scalars go to TensorBoard; model snapshots and
    sample reconstructions are written every epoch / every
    ``image_save_interval`` batches.

    NOTE(review): relies on module-level names defined elsewhere in this
    file (``parser``, ``SummaryWriter``, ``transforms``, ``get_loader``,
    ``TextEncoder``/``TextDecoder``, ``ImageEncoder``/``ImageDecoder``,
    ``Embeddings``, ``load_glove_embeddings``, ``valid_params``,
    ``mse_loss``, ``to_var``, ``AverageMeter``, ``Bar``, ``validate``).
    Uses the pre-0.4 PyTorch API (``Variable``, ``loss.data[0]``).
    """
    # global args
    args = parser.parse_args()
    writer = SummaryWriter()

    # <editor-fold desc="Initialization">

    now = datetime.datetime.now()
    current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge",
                                   "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine",
                                 "Hinge"), 'Invalid Loss Function'

    # Number of leading latent dimensions shared between the two modalities.
    mask = args.common_emb_size
    assert mask <= args.hidden_size

    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    model_path = args.model_path + current_date + args.comment + "/"

    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    #</editor-fold>

    # <editor-fold desc="Image Preprocessing">

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    #</editor-fold>

    # <editor-fold desc="Creating Embeddings">

    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # Copy the GloVe vectors into the lookup table and freeze them.
    glove_emb = Embeddings(emb_size, len(vocab.word2idx),
                           vocab.word2idx["<pad>"])
    glove_emb.word_lut.weight.data.copy_(emb)
    glove_emb.word_lut.weight.requires_grad = False

    # glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)

    # Freeze weights
    # if args.fixed_embeddings == "true":
    # glove_emb.weight.requires_grad = False

    # </editor-fold>

    # <editor-fold desc="Data-Loaders">

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir,
                            args.valid_caption_path,
                            vocab,
                            transform,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    # </editor-fold>

    # <editor-fold desc="Network Initialization">

    print("Setting up the Networks...")
    encoder_Txt = TextEncoder(glove_emb,
                              num_layers=1,
                              bidirectional=False,
                              hidden_size=args.hidden_size)
    decoder_Txt = TextDecoder(glove_emb,
                              num_layers=1,
                              bidirectional=False,
                              hidden_size=args.hidden_size)
    # decoder_Txt = TextDecoder(encoder_Txt, glove_emb)
    # decoder_Txt = DecoderRNN(glove_emb, hidden_size=args.hidden_size)

    encoder_Img = ImageEncoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)
    decoder_Img = ImageDecoder(img_dimension=args.crop_size,
                               feature_dimension=args.hidden_size)

    if cuda:
        encoder_Txt = encoder_Txt.cuda()
        decoder_Img = decoder_Img.cuda()

        encoder_Img = encoder_Img.cuda()
        decoder_Txt = decoder_Txt.cuda()

    # </editor-fold>

    # <editor-fold desc="Losses">

    # Losses and Optimizers
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    # txt_criterion = nn.MSELoss(size_average=True)
    if args.text_criterion == 'MSE':
        txt_criterion = nn.MSELoss()
    elif args.text_criterion == "Cosine":
        txt_criterion = nn.CosineEmbeddingLoss(size_average=False)
    else:
        txt_criterion = nn.HingeEmbeddingLoss(size_average=False)

    if args.cm_criterion == 'MSE':
        cm_criterion = nn.MSELoss()
    elif args.cm_criterion == "Cosine":
        cm_criterion = nn.CosineEmbeddingLoss()
    else:
        cm_criterion = nn.HingeEmbeddingLoss()

    if cuda:
        img_criterion = img_criterion.cuda()
        txt_criterion = txt_criterion.cuda()
        cm_criterion = cm_criterion.cuda()
    # txt_criterion = nn.CrossEntropyLoss()

    # </editor-fold>

    # <editor-fold desc="Optimizers">
    #     gen_params = chain(generator_A.parameters(), generator_B.parameters())
    print("Setting up the Optimizers...")
    # img_params = chain(decoder_Img.parameters(), encoder_Img.parameters())
    # txt_params = chain(decoder_Txt.decoder.parameters(), encoder_Txt.encoder.parameters())
    # img_params = list(decoder_Img.parameters()) + list(encoder_Img.parameters())
    # txt_params = list(decoder_Txt.decoder.parameters()) + list(encoder_Txt.encoder.parameters())

    # ATTENTION: Check betas and weight decay
    # ATTENTION: Check why valid_params fails on image networks with out of memory error

    # img_optim = optim.Adam(img_params, lr=0.0001, betas=(0.5, 0.999), weight_decay=0.00001)
    # txt_optim = optim.Adam(valid_params(txt_params), lr=0.0001,betas=(0.5, 0.999), weight_decay=0.00001)
    img_enc_optim = optim.Adam(
        encoder_Img.parameters(),
        lr=args.learning_rate)  #betas=(0.5, 0.999), weight_decay=0.00001)
    img_dec_optim = optim.Adam(
        decoder_Img.parameters(),
        lr=args.learning_rate)  #betas=(0.5,0.999), weight_decay=0.00001)
    txt_enc_optim = optim.Adam(
        valid_params(encoder_Txt.encoder.parameters()),
        lr=args.learning_rate)  #betas=(0.5,0.999), weight_decay=0.00001)
    txt_dec_optim = optim.Adam(
        valid_params(decoder_Txt.decoder.parameters()),
        lr=args.learning_rate)  #betas=(0.5,0.999), weight_decay=0.00001)

    # </editor-fold desc="Optimizers">

    train_images = False  # Reverse 2
    for epoch in range(args.num_epochs):

        # <editor-fold desc = "Epoch Initialization"?

        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        # Set training mode
        encoder_Img.train()
        decoder_Img.train()

        encoder_Txt.encoder.train()
        decoder_Txt.decoder.train()

        # Negative-sample weight: anneals linearly from 2 to 0 over the
        # first 10 epochs, then stays at 0.
        neg_rate = max(0, 2 * (10 - epoch) / 10)
        # </editor-fold desc = "Epoch Initialization"?

        # Alternate which pipeline gets updated each epoch.
        train_images = not train_images
        for i, (images, captions, lengths) in enumerate(data_loader):
            # ATTENTION REMOVE
            if i == len(data_loader) - 1:
                break

            # <editor-fold desc = "Training Parameters Initiliazation"?

            # Set mini-batch dataset
            images = to_var(images)
            captions = to_var(captions)

            # target = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            # captions, lengths = pad_sequences(captions, lengths)
            # images = torch.FloatTensor(images)

            # (batch, seq) -> (seq, batch, 1), the layout the seq2seq nets expect.
            captions = captions.transpose(0, 1).unsqueeze(2)
            lengths = torch.LongTensor(lengths)  # print(captions.size())

            # Forward, Backward and Optimize
            # img_optim.zero_grad()
            img_dec_optim.zero_grad()
            img_enc_optim.zero_grad()
            # encoder_Img.zero_grad()
            # decoder_Img.zero_grad()

            # txt_params.zero_grad()
            txt_dec_optim.zero_grad()
            txt_enc_optim.zero_grad()
            # encoder_Txt.encoder.zero_grad()
            # decoder_Txt.decoder.zero_grad()

            # </editor-fold desc = "Training Parameters Initiliazation"?

            # <editor-fold desc = "Image AE"?

            # Image Auto_Encoder Forward
            img_encoder_outputs, Iz = encoder_Img(images)

            IzI = decoder_Img(img_encoder_outputs)

            img_rc_loss = img_criterion(IzI, images)
            # </editor-fold desc = "Image AE"?

            # <editor-fold desc = "Seq2Seq AE"?
            # Text Auto Encoder Forward

            # target = target[:-1] # exclude last target from inputs

            captions = captions[:-1, :, :]
            lengths = lengths - 1
            dec_state = None

            encoder_outputs, memory_bank = encoder_Txt(captions, lengths)

            enc_state = \
                decoder_Txt.decoder.init_decoder_state(captions, memory_bank, encoder_outputs)

            decoder_outputs, dec_state, attns = \
                decoder_Txt.decoder(captions,
                             memory_bank,
                             enc_state if dec_state is None
                             else dec_state,
                             memory_lengths=lengths)

            Tz = encoder_outputs
            TzT = decoder_outputs

            # </editor-fold desc = "Seq2Seq AE"?

            # <editor-fold desc = "Loss accumulation"?
            if args.text_criterion == 'MSE':
                txt_rc_loss = txt_criterion(TzT, glove_emb(captions))
            else:
                # BUGFIX: Tensor.size() takes at most ONE dim argument, so the
                # original `TzT.size(0,1)` raised a TypeError on this branch.
                # The intent was a target of ones over the two leading dims.
                txt_rc_loss = txt_criterion(TzT, glove_emb(captions),\
                                            Variable(torch.ones(TzT.size(0), TzT.size(1))).cuda())
            #
            # for x,y,l in zip(TzT.transpose(0,1),glove_emb(captions).transpose(0,1),lengths):
            #     if args.criterion == 'MSE':
            #         # ATTENTION dunno what's the right one
            #         txt_rc_loss += txt_criterion(x,y)
            #     else:
            #         # ATTENTION Fails on last batch
            #         txt_rc_loss += txt_criterion(x, y, Variable(torch.ones(x.size(0))).cuda())/l
            #
            # txt_rc_loss /= captions.size(1)

            # Computes Cross-Modal Loss

            Tz = Tz[0]

            # Align only the first `mask` latent dimensions of each modality.
            txt = Tz.narrow(1, 0, mask)
            im = Iz.narrow(1, 0, mask)

            if args.cm_criterion == 'MSE':
                # cm_loss = cm_criterion(Tz.narrow(1,0,mask), Iz.narrow(1,0,mask))
                cm_loss = mse_loss(txt, im)
            else:
                cm_loss = cm_criterion(txt, im, \
                                       Variable(torch.ones(im.size(0)).cuda()))

            # K - Negative Samples: push apart randomly permuted (mismatched)
            # text/image pairs within the batch.
            k = args.negative_samples
            for _ in range(k):

                if cuda:
                    perm = torch.randperm(args.batch_size).cuda()
                else:
                    perm = torch.randperm(args.batch_size)

                # if args.criterion == 'MSE':
                #     cm_loss -= mse_loss(txt, im[perm])/k
                # else:
                #     cm_loss -= cm_criterion(txt, im[perm], \
                #                            Variable(torch.ones(Tz.narrow(1,0,mask).size(0)).cuda()))/k

                # sim  = (F.cosine_similarity(txt,txt[perm]) - 0.5)/2

                if args.cm_criterion == 'MSE':
                    sim = (F.cosine_similarity(txt, txt[perm]) - 1) / (2 * k)
                    # cm_loss = cm_criterion(Tz.narrow(1,0,mask), Iz.narrow(1,0,mask))
                    cm_loss += mse_loss(txt, im[perm], sim)
                else:
                    cm_loss += neg_rate * cm_criterion(txt, im[perm], \
                                           Variable(-1*torch.ones(txt.size(0)).cuda()))/k

            # cm_loss = Variable(torch.max(torch.FloatTensor([-0.100]).cuda(), cm_loss.data))

            # Computes the loss to be back-propagated
            img_loss = img_rc_loss * (
                1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
            txt_loss = txt_rc_loss * (
                1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
            # txt_loss = txt_rc_loss + 0.1 * cm_loss
            # img_loss = img_rc_loss + cm_loss

            txt_losses.update(txt_rc_loss.data[0], args.batch_size)
            img_losses.update(img_rc_loss.data[0], args.batch_size)
            cm_losses.update(cm_loss.data[0], args.batch_size)
            # </editor-fold desc = "Loss accumulation"?

            # <editor-fold desc = "Back Propagation">
            # Half of the times we update one pipeline the others the other one
            if train_images:
                # Image Network Training and Backpropagation

                img_loss.backward()
                # img_optim.step()
                img_enc_optim.step()
                img_dec_optim.step()

            else:
                # Text Nextwork Training & Back Propagation

                txt_loss.backward()
                # txt_optim.step()
                txt_enc_optim.step()
                txt_dec_optim.step()

            # </editor-fold desc = "Back Propagation">

            # <editor-fold desc = "Logging">
            if i % args.image_save_interval == 0:
                # BUGFIX: use floor division so the subdirectory is an integer
                # index ("0", "1", ...) under Python 3 instead of a float
                # string ("0.0"); identical to the Python-2 `/` behavior.
                subdir_path = os.path.join(result_path,
                                           str(i // args.image_save_interval))

                if os.path.exists(subdir_path):
                    pass
                else:
                    os.makedirs(subdir_path)

                # Save three original/reconstructed image pairs; undo the
                # [-1, 1] normalization before writing to disk.
                for im_idx in range(3):
                    im_or = (images[im_idx].cpu().data.numpy().transpose(
                        1, 2, 0) / 2 + .5) * 255
                    im = (IzI[im_idx].cpu().data.numpy().transpose(1, 2, 0) / 2
                          + .5) * 255

                    filename_prefix = os.path.join(subdir_path, str(im_idx))
                    scipy.misc.imsave(filename_prefix + '_original.A.jpg',
                                      im_or)
                    scipy.misc.imsave(filename_prefix + '.A.jpg', im)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                batch=i,
                size=len(data_loader),
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                img_l=img_losses.avg,
                txt_l=txt_losses.avg,
                cm_l=cm_losses.avg,
            )
            bar.next()

            # </editor-fold desc = "Logging">

        bar.finish()

        # <editor-fold desc = "Saving the models"?
        # Save the models
        print('\n')
        print('Saving the models in {}...'.format(model_path))
        torch.save(
            decoder_Img.state_dict(),
            os.path.join(model_path, 'decoder-img-%d-' % (epoch + 1)) +
            current_date + ".pkl")
        torch.save(
            encoder_Img.state_dict(),
            os.path.join(model_path, 'encoder-img-%d-' % (epoch + 1)) +
            current_date + ".pkl")
        torch.save(
            decoder_Txt.state_dict(),
            os.path.join(model_path, 'decoder-txt-%d-' % (epoch + 1)) +
            current_date + ".pkl")
        torch.save(
            encoder_Txt.state_dict(),
            os.path.join(model_path, 'encoder-txt-%d-' % (epoch + 1)) +
            current_date + ".pkl")

        # </editor-fold desc = "Saving the models"?
        # <editor-fold desc = "Validation">
        if args.validate == "true":
            print("Train Set")
            validate(encoder_Img, encoder_Txt, data_loader, mask, 10)

            print("Test Set")
            validate(encoder_Img, encoder_Txt, val_loader, mask, 10)

        # </editor-fold desc = "Validation">

        writer.add_scalars(
            'data/scalar_group', {
                'Image_RC': img_losses.avg,
                'Text_RC': txt_losses.avg,
                'CM_loss': cm_losses.avg
            }, epoch)
# Exemplo n.º 6
# 0
def main():
    """Train a VAE-style image autoencoder jointly with a GRU-style text
    autoencoder (manual encoder/decoder unrolling) under a cross-modal loss.

    Pipeline: parse CLI args -> load vocabulary and GloVe embedding matrix
    -> build train/validation caption loaders -> build the four networks ->
    per epoch, either load a saved checkpoint (``keep_loading``) or train,
    alternating per BATCH between updating the image and the text pipeline.
    The text decoder is unrolled step-by-step with teacher forcing and the
    reconstruction loss comes from ``masked_cross_entropy``; the cross-modal
    loss aligns the first ``mask`` latent dimensions, with ``k`` permuted
    negative samples annealed by ``neg_rate``.

    NOTE(review): relies on module-level names defined elsewhere in this
    file (``parser``, ``transforms``, ``get_loader``, ``TextEncoder``,
    ``TextDecoder``, ``ImageEncoder``, ``ImageDecoder``,
    ``load_glove_embeddings``, ``valid_params``, ``mse_loss``,
    ``masked_cross_entropy``, ``to_var``, ``AverageMeter``, ``Bar``,
    ``validate``).  Uses the pre-0.4 PyTorch API (``Variable``,
    ``loss.data[0]``).
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">

    now = datetime.datetime.now()
    current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE","Cosine","Hinge","NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE","Cosine","Hinge"), 'Invalid Loss Function'

    # Number of leading latent dimensions shared between the two modalities.
    mask = int(args.common_emb_percentage * args.hidden_size)
    assert mask <= args.hidden_size

    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    # NOTE(review): this is inverted relative to the sibling example above
    # (there, load_model == "NONE" means keep_loading = False).  Here
    # "NONE" attempts to load checkpoints until a FileNotFoundError flips
    # keep_loading off — confirm this is the intended semantics.
    if args.load_model == "NONE":
        keep_loading = True
        model_path = args.model_path + current_date + "/"
    else:
        keep_loading = False
        model_path = args.model_path + args.load_model + "/"

    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"




    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    #</editor-fold>

    # <editor-fold desc="Image Preprocessing">

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    #</editor-fold>

    # <editor-fold desc="Creating Embeddings">


    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    if args.embedding_path[-1]=='/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # glove_emb = Embeddings(emb_size,len(vocab.word2idx),vocab.word2idx["<pad>"])
    # glove_emb.word_lut.weight.data.copy_(emb)
    # glove_emb.word_lut.weight.requires_grad = False

    # NOTE(review): the loaded GloVe matrix `emb` is only used for its
    # shape here — the weight copy below is commented out, so this
    # embedding trains from random init.  Confirm that is intentional.
    glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)


    # Freeze weights
    # if args.fixed_embeddings == "true":
        # glove_emb.weight.requires_grad = False


    # </editor-fold>

    # <editor-fold desc="Data-Loaders">

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    # </editor-fold>

    # <editor-fold desc="Network Initialization">

    print("Setting up the Networks...")
    encoder_Txt = TextEncoder(glove_emb, num_layers=1, bidirectional=False, hidden_size=args.hidden_size)
    decoder_Txt = TextDecoder(glove_emb, len(vocab),  num_layers=1, bidirectional=False, hidden_size=args.hidden_size)
    # decoder_Txt = TextDecoder(encoder_Txt, glove_emb)
    # decoder_Txt = DecoderRNN(glove_emb, hidden_size=args.hidden_size)


    encoder_Img = ImageEncoder(img_dimension=args.crop_size,feature_dimension= args.hidden_size)
    decoder_Img = ImageDecoder(img_dimension=args.crop_size, feature_dimension= args.hidden_size)

    if cuda:
        encoder_Txt = encoder_Txt.cuda()
        decoder_Img = decoder_Img.cuda()

        encoder_Img = encoder_Img.cuda()
        decoder_Txt = decoder_Txt.cuda()

    # </editor-fold>

    # <editor-fold desc="Losses">

    # Losses and Optimizers
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    # txt_criterion = nn.MSELoss(size_average=True)
    if args.text_criterion == 'MSE':
        txt_criterion = nn.MSELoss()
    elif args.text_criterion == "Cosine":
        txt_criterion = nn.CosineEmbeddingLoss(size_average=False)
    elif args.text_criterion == "NLLLoss":
        txt_criterion = nn.NLLLoss()
    else:
        txt_criterion = nn.HingeEmbeddingLoss(size_average=False)

    if args.cm_criterion == 'MSE':
        cm_criterion = nn.MSELoss()
    elif args.cm_criterion == "Cosine":
        cm_criterion = nn.CosineEmbeddingLoss()
    else:
        cm_criterion = nn.HingeEmbeddingLoss()


    if cuda:
        img_criterion = img_criterion.cuda()
        txt_criterion = txt_criterion.cuda()
        cm_criterion = cm_criterion.cuda()
    # txt_criterion = nn.CrossEntropyLoss()

    # </editor-fold>

    # <editor-fold desc="Optimizers">
    #     gen_params = chain(generator_A.parameters(), generator_B.parameters())
    print("Setting up the Optimizers...")
    # img_params = chain(decoder_Img.parameters(), encoder_Img.parameters())
    # txt_params = chain(decoder_Txt.decoder.parameters(), encoder_Txt.encoder.parameters())
    # img_params = list(decoder_Img.parameters()) + list(encoder_Img.parameters())
    # txt_params = list(decoder_Txt.decoder.parameters()) + list(encoder_Txt.encoder.parameters())

    # ATTENTION: Check betas and weight decay
    # ATTENTION: Check why valid_params fails on image networks with out of memory error

    # img_optim = optim.Adam(img_params, lr=0.0001, betas=(0.5, 0.999), weight_decay=0.00001)
    # txt_optim = optim.Adam(valid_params(txt_params), lr=0.0001,betas=(0.5, 0.999), weight_decay=0.00001)
    img_enc_optim = optim.Adam(encoder_Img.parameters(), lr=args.learning_rate)#betas=(0.5, 0.999), weight_decay=0.00001)
    img_dec_optim = optim.Adam(decoder_Img.parameters(), lr=args.learning_rate)#betas=(0.5,0.999), weight_decay=0.00001)
    txt_enc_optim = optim.Adam(valid_params(encoder_Txt.parameters()), lr=args.learning_rate)#betas=(0.5,0.999), weight_decay=0.00001)
    txt_dec_optim = optim.Adam(valid_params(decoder_Txt.parameters()), lr=args.learning_rate)#betas=(0.5,0.999), weight_decay=0.00001)

    # </editor-fold desc="Optimizers">

    train_images = False # Reverse 2

    for epoch in range(args.num_epochs):

        # <editor-fold desc = "Epoch Initialization"?

        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        # Try to resume this epoch from checkpoints; on the first miss we
        # fall through to training for the rest of the run.
        if keep_loading:
            suffix = "-" + str(epoch) + "-" + args.load_model + ".pkl"
            try:
                encoder_Img.load_state_dict(torch.load(os.path.join(args.model_path,
                                        'encoder-img' + suffix)))
                encoder_Txt.load_state_dict(torch.load(os.path.join(args.model_path,
                                        'encoder-txt' + suffix)))
                decoder_Img.load_state_dict(torch.load(os.path.join(args.model_path,
                                        'decoder-img' + suffix)))
                decoder_Txt.load_state_dict(torch.load(os.path.join(args.model_path,
                                        'decoder-txt' + suffix)))
            except FileNotFoundError:
                print("Didn't find any models switching to training")
                keep_loading = False

        if not keep_loading:

            # Set training mode
            encoder_Img.train()
            decoder_Img.train()

            encoder_Txt.train()
            decoder_Txt.train()

            # </editor-fold desc = "Epoch Initialization"?

            train_images = not train_images
            for i, (images, captions, lengths) in enumerate(data_loader):

                if i == len(data_loader)-1:
                    break


                # <editor-fold desc = "Training Parameters Initiliazation"?

                # Set mini-batch dataset
                images = to_var(images)
                captions = to_var(captions)

                # target = pack_padded_sequence(captions, lengths, batch_first=True)[0]
                # captions, lengths = pad_sequences(captions, lengths)
                # images = torch.FloatTensor(images)

                # (batch, seq) -> (seq, batch, 1) for step-wise decoding.
                captions = captions.transpose(0,1).unsqueeze(2)
                lengths = to_var(torch.LongTensor(lengths))            # print(captions.size())


                # Forward, Backward and Optimize
                # img_optim.zero_grad()
                img_dec_optim.zero_grad()
                img_enc_optim.zero_grad()
                # encoder_Img.zero_grad()
                # decoder_Img.zero_grad()

                # txt_params.zero_grad()
                txt_dec_optim.zero_grad()
                txt_enc_optim.zero_grad()
                # encoder_Txt.encoder.zero_grad()
                # decoder_Txt.decoder.zero_grad()

                # </editor-fold desc = "Training Parameters Initiliazation"?

                # <editor-fold desc = "Image AE"?

                # Image Auto_Encoder Forward
                mu, logvar  = encoder_Img(images)

                # NOTE(review): the VAE reparametrization is disabled — the
                # "latent" used for the cross-modal loss is logvar and the
                # decoder reconstructs from mu.  Confirm this is deliberate.
                Iz = logvar
                # Iz = reparametrize(mu, logvar)
                IzI = decoder_Img(mu)

                img_rc_loss = img_criterion(IzI,images)
                # </editor-fold desc = "Image AE"?

                # <editor-fold desc = "Seq2Seq AE"?
                # Text Auto Encoder Forward

                # target = target[:-1] # exclude last target from inputs

                teacher_forcing_ratio = 0.5

                encoder_hidden = encoder_Txt.initHidden(args.batch_size)

                input_length = captions.size(0)
                target_length = captions.size(0)

                if cuda:
                    encoder_outputs = Variable(torch.zeros(input_length, args.batch_size, args.hidden_size).cuda())
                    decoder_outputs = Variable(torch.zeros(input_length, args.batch_size, len(vocab)).cuda())
                else:
                    encoder_outputs = Variable(torch.zeros(input_length, args.batch_size, args.hidden_size))
                    decoder_outputs = Variable(torch.zeros(input_length, args.batch_size, len(vocab)))

                txt_rc_loss = 0

                # Unroll the encoder one token step at a time.
                for ei in range(input_length):
                    encoder_output, encoder_hidden = encoder_Txt(
                    captions[ei,:], encoder_hidden)
                    encoder_outputs[ei] = encoder_output

                decoder_input = Variable(torch.LongTensor([vocab.word2idx['<start>']])).cuda()\
                    .repeat(args.batch_size,1)


                decoder_hidden = encoder_hidden

                use_teacher_forcing = True #if np.random.random() < teacher_forcing_ratio else False

                if use_teacher_forcing:
                    # Teacher forcing: Feed the target as the next input
                    for di in range(target_length-1):
                        decoder_output, decoder_hidden = decoder_Txt(
                        decoder_input, decoder_hidden) #, encoder_outputs)
                # txt_rc_loss += txt_criterion(decoder_output, captions[di].unsqueeze(1))

                        decoder_outputs[di] = decoder_output

                        decoder_input = captions[di+1]  # Teacher forcing

                else:
                # Without teacher forcing: use its own predictions as the next input
                    for di in range(target_length-1):
                        # BUGFIX: was `decoder_outputs, decoder_hidden = ...`,
                        # which clobbered the pre-allocated output buffer and
                        # left `decoder_output` stale/undefined for topk below.
                        decoder_output, decoder_hidden = decoder_Txt(
                        decoder_input, decoder_hidden)
                        topv, topi = decoder_output.topk(1)
                        decoder_input = topi.squeeze().detach()  # detach from history as input

                        txt_rc_loss += txt_criterion(decoder_output, captions[di])
                # if decoder_input.item() == ("<end>"):
                #     break

                # Check start tokens etc
                # Sequence-level masked NLL over the unrolled predictions,
                # shifted by one against the targets.
                txt_rc_loss, _, _, _ = masked_cross_entropy(
                decoder_outputs[:target_length-1].transpose(0, 1).contiguous(),
                                captions[1:,:,0].transpose(0, 1).contiguous(),
                                lengths - 1
                )


                # captions = captions[:-1,:,:]
                # lengths = lengths - 1
                # dec_state = None

                # Computes Cross-Modal Loss

                # Tz = encoder_hidden[0]
                Tz = encoder_output[:,0,:]

                # Align only the first `mask` latent dimensions.
                txt =  Tz.narrow(1,0,mask)
                im = Iz.narrow(1,0,mask)

                if args.cm_criterion == 'MSE':
                    # cm_loss = cm_criterion(Tz.narrow(1,0,mask), Iz.narrow(1,0,mask))
                    cm_loss = mse_loss(txt, im)
                else:
                    cm_loss = cm_criterion(txt, im, \
                    Variable(torch.ones(im.size(0)).cuda()))

                # K - Negative Samples: push apart mismatched pairs, with a
                # weight that anneals linearly to 0 by epoch 20.
                k = args.negative_samples
                neg_rate = (20-epoch)/20
                for _ in range(k):

                    if cuda:
                        perm = torch.randperm(args.batch_size).cuda()
                    else:
                        perm = torch.randperm(args.batch_size)

                    # if args.criterion == 'MSE':
                    #     cm_loss -= mse_loss(txt, im[perm])/k
                    # else:
                    #     cm_loss -= cm_criterion(txt, im[perm], \
                    #                            Variable(torch.ones(Tz.narrow(1,0,mask).size(0)).cuda()))/k

                    # sim  = (F.cosine_similarity(txt,txt[perm]) - 0.5)/2

                    if args.cm_criterion == 'MSE':
                        sim  = (F.cosine_similarity(txt,txt[perm]) - 1)/(2*k)
                        # cm_loss = cm_criterion(Tz.narrow(1,0,mask), Iz.narrow(1,0,mask))
                        cm_loss += mse_loss(txt, im[perm], sim)
                    else:
                        cm_loss += neg_rate * cm_criterion(txt, im[perm], \
                        Variable(-1*torch.ones(txt.size(0)).cuda()))/k


                # cm_loss = Variable(torch.max(torch.FloatTensor([-0.100]).cuda(), cm_loss.data))


                # Computes the loss to be back-propagated
                img_loss = img_rc_loss * (1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
                txt_loss = txt_rc_loss * (1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
                # txt_loss = txt_rc_loss + 0.1 * cm_loss
                # img_loss = img_rc_loss + cm_loss

                txt_losses.update(txt_rc_loss.data[0],args.batch_size)
                img_losses.update(img_rc_loss.data[0],args.batch_size)
                cm_losses.update(cm_loss.data[0], args.batch_size)
                # </editor-fold desc = "Loss accumulation"?

                # <editor-fold desc = "Back Propagation">
                # Half of the times we update one pipeline the others the other one
                if train_images:
                # Image Network Training and Backpropagation

                    img_loss.backward()
                    # img_optim.step()
                    img_enc_optim.step()
                    img_dec_optim.step()

                else:
                    # Text Nextwork Training & Back Propagation

                    txt_loss.backward()
                    # txt_optim.step()
                    txt_enc_optim.step()
                    txt_dec_optim.step()

                train_images = not train_images
                # </editor-fold desc = "Back Propagation">

                # <editor-fold desc = "Logging">
                if i % args.image_save_interval == 0:
                    # BUGFIX: floor division keeps the subdirectory name an
                    # integer index under Python 3 (was a float like "0.0").
                    subdir_path = os.path.join( result_path, str(i // args.image_save_interval) )

                    if os.path.exists( subdir_path ):
                        pass
                    else:
                        os.makedirs( subdir_path )

                    # Dump three original/reconstructed image pairs plus the
                    # original and decoded captions for a quick eyeball check.
                    for im_idx in range(3):
                        im_or = (images[im_idx].cpu().data.numpy().transpose(1,2,0)/2+.5)*255
                        im = (IzI[im_idx].cpu().data.numpy().transpose(1,2,0)/2+.5)*255

                        filename_prefix = os.path.join (subdir_path, str(im_idx))
                        scipy.misc.imsave( filename_prefix + '_original.A.jpg', im_or)
                        scipy.misc.imsave( filename_prefix + '.A.jpg', im)


                        txt_or = " ".join([vocab.idx2word[c] for c in list(captions[:,im_idx].view(-1).cpu().data)])
                        txt = " ".join([vocab.idx2word[c] for c in list(decoder_outputs[:,im_idx].view(-1).cpu().data)])
                        print("Original: ", txt_or)
                        print(txt)


                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                    batch=i,
                    size=len(data_loader),
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    img_l=img_losses.avg,
                    txt_l=txt_losses.avg,
                    cm_l=cm_losses.avg,
                    )
                bar.next()

                                                                         # </editor-fold desc = "Logging">

            bar.finish()

            # <editor-fold desc = "Saving the models"?
            # Save the models
            print('\n')
            print('Saving the models in {}...'.format(model_path))
            torch.save(decoder_Img.state_dict(),
                       os.path.join(model_path,
                                    'decoder-img-%d-' %(epoch+1)) + current_date + ".pkl")
            torch.save(encoder_Img.state_dict(),
                       os.path.join(model_path,
                                    'encoder-img-%d-' %(epoch+1)) + current_date + ".pkl")
            torch.save(decoder_Txt.state_dict(),
                       os.path.join(model_path,
                                    'decoder-txt-%d-' %(epoch+1)) + current_date + ".pkl")
            torch.save(encoder_Txt.state_dict(),
                       os.path.join(model_path,
                                    'encoder-txt-%d-' %(epoch+1)) + current_date + ".pkl")

            # </editor-fold desc = "Saving the models"?

        if args.validate == "true":
            validate(encoder_Img, encoder_Txt, val_loader, mask, 10)
# Exemplo n.º 7
# 0
def main():
    """Validation-only driver: for each epoch, load saved image/text encoder
    checkpoints and evaluate cross-modal retrieval (MSE between paired
    embeddings and top-k nearest-neighbour accuracy) on the validation set.

    NOTE(review): this function appears to have been truncated when it was
    extracted -- it references names that are never bound here (``i``,
    ``txt_emb``, ``img_emb``, ``bar``, ``batch_time``, ``end``, ``limit``,
    ``result_embeddings``), and the statements after ``break`` in the
    ``except`` clause are unreachable.  The ``for i, ... in
    enumerate(val_loader)`` loop that once bound those names is presumably
    missing; restore it before running this code.
    """
    # global args
    args = parser.parse_args()

    assert args.criterion in ("MSE","Cosine","Hinge"), 'Invalid Loss Function'

    # --cuda is parsed as the string "true"/"false"; normalize to a bool.
    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    result_path = args.result_path
    model_path = args.model_path

    # Create the output directories if they do not exist yet.
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)


    # Load vocabulary wrapper.
    print('\n')
    print("\033[94mLoading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings
    emb_size = args.embedding_size
    emb_path = args.embedding_path
    # If a directory was given, default to the matching GloVe 6B file.
    if args.embedding_path[-1]=='/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # Copy the pre-trained GloVe vectors into the lookup table and freeze
    # them so they cannot be updated.
    glove_emb = Embeddings(emb_size,len(vocab.word2idx),vocab.word2idx["<pad>"])
    glove_emb.word_lut.weight.data.copy_(emb)
    glove_emb.word_lut.weight.requires_grad = False

    # glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)


    # Freeze weighs
    # if args.fixed_embeddings == "true":
        # glove_emb.weight.requires_grad = False

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir, args.caption_path, vocab,

                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir, args.valid_caption_path, vocab,
                             transform, args.batch_size,
                             shuffle=True, num_workers=args.num_workers)

    # Only the two encoders are needed for retrieval evaluation; the decoder
    # variants below are deliberately left commented out.
    print("Setting up the Networks...")
    encoder_Txt = TextEncoderOld(glove_emb, num_layers=1, bidirectional=False, hidden_size=args.hidden_size)
    # decoder_Txt = TextDecoderOld(glove_emb, num_layers=1, bidirectional=False, hidden_size=args.hidden_size)
    # decoder_Txt = TextDecoder(encoder_Txt, glove_emb)
    # decoder_Txt = DecoderRNN(glove_emb, hidden_size=args.hidden_size)


    encoder_Img = ImageEncoder(img_dimension=args.crop_size,feature_dimension= args.hidden_size)
    # decoder_Img = ImageDecoder(img_dimension=args.crop_size, feature_dimension= args.hidden_size)

    if cuda:
        encoder_Txt = encoder_Txt.cuda()

        encoder_Img = encoder_Img.cuda()


    for epoch in range(args.num_epochs):


        # VALIDATION TIME
        print('\033[92mEPOCH ::: VALIDATION ::: ' + str(epoch + 1))

        # Load the models
        print("Loading the models...")

        # suffix = '-{}-05-28-13-14.pkl'.format(epoch+1)
        # mask = 300

        # NOTE(review): the repeated reassignments of ``prefix``, ``suffix``,
        # ``date`` and ``mask`` below are leftovers of past experiment runs;
        # only the LAST assignment of each name takes effect
        # (date = "07-08-15-12", mask = 100).
        prefix = ""
        suffix = '-{}-05-28-09-23.pkl'.format(epoch+1)
        # suffix = '-{}-05-28-11-35.pkl'.format(epoch+1)
        # suffix = '-{}-05-28-16-45.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-00-28.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-00-30.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-01-08.pkl'.format(epoch+1)
        mask = 200

        # suffix = '-{}-05-28-15-39.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-12-11.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-12-14.pkl'.format(epoch+1)
        # suffix = '-{}-05-29-14-24.pkl'.format(epoch+1) #best
        # suffix = '-{}-05-29-15-43.pkl'.format(epoch+1)
        date = "06-30-14-22"
        date = "07-01-12-49" #bad
        date = "07-01-16-38"
        date = "07-01-18-16"
        date = "07-02-15-38"
        date = "07-08-15-12"
        prefix = "{}/".format(date)
        suffix = '-{}-{}.pkl'.format(epoch+1,date)
        mask = 100

        print(suffix)
        try:
            encoder_Img.load_state_dict(torch.load(os.path.join(args.model_path,
                                    prefix + 'encoder-img' + suffix)))
            encoder_Txt.load_state_dict(torch.load(os.path.join(args.model_path,
                                    prefix + 'encoder-txt' + suffix)))
        except FileNotFoundError:
            print("\n\033[91mFile not found...\nTerminating Validation Procedure!")
            break

            # NOTE(review): everything from here to ``bar.next()`` is dead
            # code (it follows ``break``) and references loop variables that
            # this function never defines -- presumably the body of a missing
            # ``for i, ... in enumerate(val_loader)`` loop.  Stacks the text
            # and image embeddings into result_embeddings[0]/[1].
            current_embeddings = np.concatenate( \
                (txt_emb.cpu().data.numpy(),\
                 img_emb.unsqueeze(0).cpu().data.numpy())\
                ,0)

            # current_embeddings = img_emb.data
            if i:
                # result_embeddings = torch.cat( \
                result_embeddings = np.concatenate( \
                    (result_embeddings, current_embeddings) \
                    ,1)
            else:
                result_embeddings = current_embeddings

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                        batch=i,
                        size=len(val_loader),
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        )
            bar.next()
        bar.finish()


        # NOTE(review): ``limit`` and ``result_embeddings`` are undefined at
        # this point -- see the truncation note in the docstring.
        # Mean squared distance between each paired text/image embedding.
        a = [((result_embeddings[0][i] - result_embeddings[1][i]) ** 2).mean() for i in range(limit*args.batch_size)]
        print("Validation MSE: ",np.mean(a))
        print("Validation MSE: ",np.mean(a))  # NOTE(review): duplicated print

        print("Computing Nearest Neighbors...")
        i = 0
        topk = []
        kss = [1,10,50]
        for k in kss:

            # NOTE(review): ``i`` is never incremented in this loop, so the
            # normalization branch below never runs.
            if i:
                print("Normalized ")
                result_embeddings[0] = result_embeddings[0]/result_embeddings[0].sum()
                result_embeddings[1] = result_embeddings[1]/result_embeddings[1].sum()

            # k = 5
            # Fit on image embeddings, query with text embeddings.
            # NOTE(review): positional ``n_neighbors`` was deprecated and then
            # removed in scikit-learn; newer versions require
            # NearestNeighbors(n_neighbors=k, metric='cosine').
            neighbors = NearestNeighbors(k, metric = 'cosine')
            neigh = neighbors
            neigh.fit(result_embeddings[1])
            kneigh = neigh.kneighbors(result_embeddings[0], return_distance=False)

            # Fraction of distinct images appearing in some query's top-k.
            ks = set()
            for n in kneigh:
                ks.update(set(n))

            print(len(ks)/result_embeddings.shape[1])

            # a = [((result_embeddings[0][i] - result_embeddings[1][i]) ** 2).mean() for i in range(128)]
            # rs = result_embeddings.sum(2)
            # a = (((result_embeddings[0][0]- result_embeddings[1][0])**2).mean())
            # b = (((result_embeddings[0][0]- result_embeddings[0][34])**2).mean())
            # Top-k accuracy: query i retrieves its own paired image i.
            topk.append(np.mean([int(i in nn) for i,nn in enumerate(kneigh)]))

        print("Top-{k:},{k2:},{k3:} accuracy for Image Retrieval:\n\n\t\033[95m {tpk: .3f}% \t {tpk2: .3f}% \t {tpk3: .3f}% \n".format(
                      k=kss[0],
                      k2=kss[1],
                      k3=kss[2],
                      tpk= 100*topk[0],
                      tpk2= 100*topk[1],
                      tpk3= 100*topk[2]))
Exemplo n.º 8
0
def main():
    """Train paired image/text VAEs tied together by a cross-modal loss.

    Command-line arguments (``parser``) choose the loss functions, data
    locations and checkpointing behaviour.  Per epoch: either resume from a
    saved checkpoint (``--load_model``) or run one pass over the training
    set, alternating (in principle) between optimising the image and the
    text pipeline, periodically dumping reconstructed images and captions,
    then optionally validate.

    NOTE(review): ``train_images`` starts ``True`` and the toggle at the end
    of the batch loop is commented out, so only the image VAE is ever
    optimised and ``step`` (the KL-annealing step counter) never advances.
    Confirm this is the intended experiment state.
    """
    # global args
    args = parser.parse_args()

    # <editor-fold desc="Initialization">
    if args.comment == "test":
        print("WARNING: name is test!!!\n\n")

    # now = datetime.datetime.now()
    # current_date = now.strftime("%m-%d-%H-%M")

    assert args.text_criterion in ("MSE", "Cosine", "Hinge",
                                   "NLLLoss"), 'Invalid Loss Function'
    assert args.cm_criterion in ("MSE", "Cosine",
                                 "Hinge"), 'Invalid Loss Function'

    assert args.common_emb_ratio <= 1.0 and args.common_emb_ratio >= 0

    # Number of leading latent dimensions shared between the two modalities.
    mask = int(args.common_emb_ratio * args.hidden_size)

    # --cuda arrives as the string "true"/"false"; normalize to a bool.
    cuda = args.cuda
    if cuda == 'true':
        cuda = True
    else:
        cuda = False

    if args.load_model == "NONE":
        keep_loading = False
        # model_path = args.model_path + current_date + "/"
        model_path = args.model_path + args.comment + "/"
    else:
        keep_loading = True
        model_path = args.model_path + args.load_model + "/"

    result_path = args.result_path
    if result_path == "NONE":
        result_path = model_path + "results/"

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    #</editor-fold>

    # <editor-fold desc="Image Preprocessing">

    # Image preprocessing //ATTENTION
    # For normalization, see https://github.com/pytorch/vision#models
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Inverse of the Normalize above, for visualising reconstructions.
    # Fixed: the blue-channel constant was 0.255 (a typo) -- the forward
    # transform uses std 0.225.
    inv_normalize = transforms.Normalize(
        mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
        std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
    #</editor-fold>

    # <editor-fold desc="Creating Embeddings">

    # Load vocabulary wrapper.
    print("Loading Vocabulary...")
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Load Embeddings
    emb_size = args.word_embedding_size
    emb_path = args.embedding_path
    # If a directory was given, default to the matching GloVe 6B file.
    if args.embedding_path[-1] == '/':
        emb_path += 'glove.6B.' + str(emb_size) + 'd.txt'

    print("Loading Embeddings...")
    emb = load_glove_embeddings(emb_path, vocab.word2idx, emb_size)

    # glove_emb = Embeddings(emb_size,len(vocab.word2idx),vocab.word2idx["<pad>"])
    # glove_emb.word_lut.weight.data.copy_(emb)
    # glove_emb.word_lut.weight.requires_grad = False

    # NOTE(review): the embedding table is randomly initialised -- the
    # pre-trained GloVe vectors loaded into ``emb`` above are never copied
    # in (the copy is commented out below).  Confirm this is intentional.
    glove_emb = nn.Embedding(emb.size(0), emb.size(1))
    # glove_emb = embedding(emb.size(0), emb.size(1))
    # glove_emb.weight = nn.Parameter(emb)

    # Freeze weighs
    # if args.fixed_embeddings == "true":
    # glove_emb.weight.requires_grad = False

    # </editor-fold>

    # <editor-fold desc="Data-Loaders">

    # Build data loader
    print("Building Data Loader For Test Set...")
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    print("Building Data Loader For Validation Set...")
    val_loader = get_loader(args.valid_dir,
                            args.valid_caption_path,
                            vocab,
                            transform,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    # </editor-fold>

    # <editor-fold desc="Network Initialization">

    print("Setting up the Networks...")
    vae_Txt = SentenceVAE(glove_emb,
                          len(vocab),
                          hidden_size=args.hidden_size,
                          latent_size=args.latent_size,
                          batch_size=args.batch_size)
    vae_Img = ImgVAE(img_dimension=args.crop_size,
                     hidden_size=args.hidden_size,
                     latent_size=args.latent_size)

    if cuda:
        vae_Txt = vae_Txt.cuda()
        vae_Img = vae_Img.cuda()

    # </editor-fold>

    # <editor-fold desc="Losses">

    # Losses and Optimizers
    # NOTE(review): ``img_criterion`` and ``txt_criterion`` are constructed
    # but never used in the training loop below (reconstruction losses come
    # from ``img_vae_loss`` / ``seq_vae_loss``); only ``cm_criterion`` is
    # consumed, via ``crossmodal_loss``.
    print("Setting up the Objective Functions...")
    img_criterion = nn.MSELoss()
    # txt_criterion = nn.MSELoss(size_average=True)
    if args.text_criterion == 'MSE':
        txt_criterion = nn.MSELoss()
    elif args.text_criterion == "Cosine":
        txt_criterion = nn.CosineEmbeddingLoss(size_average=False)
    elif args.text_criterion == "NLLLoss":
        txt_criterion = nn.NLLLoss()
    else:
        txt_criterion = nn.HingeEmbeddingLoss(size_average=False)

    if args.cm_criterion == 'MSE':
        cm_criterion = nn.MSELoss()
    elif args.cm_criterion == "Cosine":
        cm_criterion = nn.CosineEmbeddingLoss()
    else:
        cm_criterion = nn.HingeEmbeddingLoss()

    if cuda:
        img_criterion = img_criterion.cuda()
        txt_criterion = txt_criterion.cuda()
        cm_criterion = cm_criterion.cuda()
    # txt_criterion = nn.CrossEntropyLoss()

    # </editor-fold>

    # <editor-fold desc="Optimizers">
    print("Setting up the Optimizers...")

    img_optim = optim.Adam(vae_Img.parameters(),
                           lr=args.learning_rate,
                           betas=(0.5, 0.999),
                           weight_decay=0.00001)
    txt_optim = optim.Adam(vae_Txt.parameters(),
                           lr=args.learning_rate,
                           betas=(0.5, 0.999),
                           weight_decay=0.00001)

    # </editor-fold desc="Optimizers">

    train_images = True  # Reverse 2

    step = 0  # global step counter used for KL-annealing in seq_vae_loss
    for epoch in range(args.num_epochs):

        # <editor-fold desc = "Epoch Initialization"?

        # TRAINING TIME
        print('EPOCH ::: TRAINING ::: ' + str(epoch + 1))
        batch_time = AverageMeter()
        txt_losses = AverageMeter()
        img_losses = AverageMeter()
        cm_losses = AverageMeter()
        end = time.time()

        bar = Bar('Training Net', max=len(data_loader))

        # Try to resume this epoch from a checkpoint; fall back to training
        # as soon as a checkpoint is missing.
        # NOTE(review): this expects files named
        # ``vae-img-{epoch}-{load_model}.pkl`` directly under
        # ``args.model_path``, which does not match the names/location
        # written by the save code below -- verify the resume path.
        if keep_loading:
            suffix = "-" + str(epoch) + "-" + args.load_model + ".pkl"
            try:
                vae_Img.load_state_dict(
                    torch.load(
                        os.path.join(args.model_path, 'vae-img' + suffix)))
                vae_Txt.load_state_dict(
                    torch.load(
                        os.path.join(args.model_path, 'vae-txt' + suffix)))
            except FileNotFoundError:
                print("Didn't find any models switching to training")
                keep_loading = False

        if not keep_loading:

            # Set training mode
            vae_Txt.train()
            vae_Img.train()

            # </editor-fold desc = "Epoch Initialization"?

            # train_images = not train_images
            for i, (images, captions, lengths) in enumerate(data_loader):

                # Skip the last (possibly partial) batch: the models are
                # built for a fixed batch_size.
                if i == len(data_loader) - 1:
                    break

                # <editor-fold desc = "Training Parameters Initiliazation"?

                # Set mini-batch dataset
                images = to_var(images)
                captions = to_var(captions)

                # captions = captions.transpose(0,1).unsqueeze(2)
                lengths = to_var(
                    torch.LongTensor(lengths))  # print(captions.size())

                # Forward, Backward and Optimize
                img_optim.zero_grad()
                txt_optim.zero_grad()

                # </editor-fold desc = "Training Parameters Initiliazation"?

                # <editor-fold desc = "Forward passes"?

                img_out, img_mu, img_logv, img_z = vae_Img(images)
                txt_out, txt_mu, txt_logv, txt_z = vae_Txt(captions, lengths)

                # Per-pixel-normalised image ELBO.
                img_rc_loss = img_vae_loss(
                    img_out, images, img_mu,
                    img_logv) / (args.batch_size * args.crop_size**2)

                # Sequence ELBO with logistic KL annealing driven by `step`.
                NLL_loss, KL_loss, KL_weight = seq_vae_loss(
                    txt_out, captions, lengths, txt_mu, txt_logv, "logistic",
                    step, 0.0025, 2500)

                # Normalise by the total number of tokens in the batch.
                txt_rc_loss = (NLL_loss + KL_weight *
                               KL_loss) / torch.sum(lengths).float()

                # Align the first `mask` latent dimensions across modalities.
                cm_loss = crossmodal_loss(txt_z, img_z, mask,
                                          args.cm_criterion, cm_criterion,
                                          args.negative_samples, epoch)

                # cm_loss += crossmodal_loss(txt_logv, img_logv, mask,
                #                           args.cm_criterion, cm_criterion,
                #                           args.negative_samples, epoch)

                # Computes the loss to be back-propagated
                img_loss = img_rc_loss * (
                    1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
                txt_loss = txt_rc_loss * (
                    1 - args.cm_loss_weight) + cm_loss * args.cm_loss_weight
                # txt_loss = txt_rc_loss +  cm_loss * args.cm_loss_weight
                # img_loss = img_rc_loss + cm_loss * args.cm_loss_weight

                txt_losses.update(txt_rc_loss.data[0], args.batch_size)
                img_losses.update(img_rc_loss.data[0], args.batch_size)
                cm_losses.update(cm_loss.data[0], args.batch_size)
                # </editor-fold desc = "Loss accumulation"?

                # <editor-fold desc = "Back Propagation">
                # Half of the times we update one pipeline the others the other one
                if train_images:
                    # Image Network Training and Backpropagation

                    img_loss.backward()
                    img_optim.step()

                else:
                    # Text Nextwork Training & Back Propagation
                    txt_loss.backward()
                    txt_optim.step()

                    step += 1

                # train_images = not train_images
                # </editor-fold desc = "Back Propagation">

                # <editor-fold desc = "Logging">
                # Periodically dump sample reconstructions to disk.
                if i % args.image_save_interval == 0:
                    # Fixed: use integer division -- true division created
                    # float-named directories ("0.0", "1.0", ...) on
                    # Python 3.
                    subdir_path = os.path.join(
                        result_path, str(i // args.image_save_interval))

                    if os.path.exists(subdir_path):
                        pass
                    else:
                        os.makedirs(subdir_path)

                    for im_idx in range(3):
                        # im_or = (inv_normalize([im_idx]).cpu().data.numpy().transpose(1,2,0))*255
                        # im = (inv_normalize([im_idx]).cpu().data.numpy().transpose(1,2,0))*255
                        # Undo the (x/2 + .5)-style scaling for viewing.
                        im_or = (images[im_idx].cpu().data.numpy().transpose(
                            1, 2, 0) / 2 + .5) * 255
                        im = (img_out[im_idx].cpu().data.numpy().transpose(
                            1, 2, 0) / 2 + .5) * 255
                        # im = img_out[im_idx].cpu().data.numpy().transpose(1,2,0)*255

                        filename_prefix = os.path.join(subdir_path,
                                                       str(im_idx))
                        # NOTE(review): scipy.misc.imsave was removed in
                        # SciPy >= 1.2; switch to imageio.imwrite when
                        # upgrading.
                        scipy.misc.imsave(filename_prefix + '_original.A.jpg',
                                          im_or)
                        scipy.misc.imsave(filename_prefix + '.A.jpg', im)

                        # Decode original and generated captions to words.
                        txt_or = " ".join([
                            vocab.idx2word[c]
                            for c in captions[im_idx].cpu().data.numpy()
                        ])
                        _, generated = torch.topk(txt_out[im_idx], 1)
                        txt = " ".join([
                            vocab.idx2word[c]
                            for c in generated[:, 0].cpu().data.numpy()
                        ])

                        with open(filename_prefix + "_captions.txt",
                                  "w") as text_file:
                            text_file.write("Original: %s\n" % txt_or)
                            text_file.write("Generated: %s" % txt)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_Img: {img_l:.3f}| Loss_Txt: {txt_l:.3f} | Loss_CM: {cm_l:.4f}'.format(
                    batch=i,
                    size=len(data_loader),
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    img_l=img_losses.avg,
                    txt_l=txt_losses.avg,
                    cm_l=cm_losses.avg,
                )
                bar.next()

            # </editor-fold desc = "Logging">

            bar.finish()

            # <editor-fold desc = "Saving the models"?
            # Save the models
            # Fixed: the previous format string left a dangling dash in the
            # file name ("vae-img-1-.pkl") after the date tag was removed.
            print('\n')
            print('Saving the models in {}...'.format(model_path))
            torch.save(
                vae_Img.state_dict(),
                os.path.join(model_path, 'vae-img-%d.pkl' % (epoch + 1)))
            torch.save(
                vae_Txt.state_dict(),
                os.path.join(model_path, 'vae-txt-%d.pkl' % (epoch + 1)))

            # </editor-fold desc = "Saving the models"?

        if args.validate == "true":
            validate(vae_Img, vae_Txt, val_loader, mask, 10)