def main(args):
    """Train the encoder/decoder captioning model and optionally sample captions.

    Per epoch: one training pass, one no-grad validation pass, BLEU-1..4
    evaluation, TensorBoard logging, LR decay, and checkpointing. After
    training, optionally runs beam-search prediction on random train/val
    samples and prints them.

    Args:
        args: parsed CLI namespace. Fields read here: PCA, norm, training,
            testing, hidden_size, embed_size, attention_size, load_weight,
            save_weight, model_path, num_epochs, learning_rate, predict,
            numOfpredection, show_image.
    """
    data = DataLoader(pca=args.PCA, norm=args.norm)

    train_captions, train_feature, train_url, train_len = data.get_Training_data(
        args.training)
    test_captions, test_feature, test_url, test_len = data.get_val_data(
        args.testing)
    # Features/captions held out purely for BLEU evaluation.
    eval_features, eval_captions, _ = data.eval_data()

    writer = SummaryWriter()

    encoder = Encoder(input_size=train_feature.shape[1],
                      hidden_size=args.hidden_size).to(device)
    decoder = Decoder(embed_size=args.embed_size,
                      hidden_size=args.hidden_size,
                      attention_dim=args.attention_size,
                      vocab_size=len(data.word_to_idx)).to(device)

    if args.load_weight:
        load_weights(encoder, args.model_path + "Jul28_10-04-57encoder")
        load_weights(decoder, args.model_path + "Jul28_10-04-57decoder")

    # Build the loss and optimizer ONCE. The previous version re-created the
    # Adam optimizer every epoch, which silently discarded its running moment
    # estimates; the learning rate is instead decayed in-place each epoch below.
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params=params, lr=args.learning_rate)

    for epoch in range(args.num_epochs):
        training_loss = step(encoder=encoder,
                             decoder=decoder,
                             criterion=criterion,
                             data=(train_captions, train_feature, train_len),
                             optimizer=optimizer)

        # Validation pass: no optimizer, no gradients.
        # NOTE(review): encoder/decoder stay in train mode here and during the
        # BLEU evaluation below — confirm whether .eval() is intended (matters
        # if the models use dropout/batch-norm).
        with torch.no_grad():
            test_loss = step(encoder=encoder,
                             decoder=decoder,
                             criterion=criterion,
                             data=(test_captions, test_feature, test_len))

        b1, b2, b3, b4 = evaluate(encoder, decoder, eval_features,
                                  eval_captions, 5, data.word_to_idx,
                                  data.idx_to_word)
        writer.add_scalars('BLEU', {
            'BLEU1': b1,
            'BLEU2': b2,
            'BLEU3': b3,
            'BLEU4': b4
        }, epoch + 1)
        writer.add_scalars('loss', {
            'train': training_loss,
            'val': test_loss
        }, epoch + 1)

        print(
            'Epoch [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}, TestLoss: {:.4f}, TestPerplexity: {:5.4f}'
            .format(epoch + 1, args.num_epochs, training_loss,
                    np.exp(training_loss), test_loss, np.exp(test_loss)))

        # Geometric LR decay applied in-place so Adam's state survives.
        args.learning_rate *= 0.995
        for group in optimizer.param_groups:
            group['lr'] = args.learning_rate

        # Single checkpoint per epoch. The previous version wrote the SAME
        # files twice whenever save_weight was on and epoch % 30 == 0; this
        # preserves the union of both triggers while saving once.
        if args.save_weight or epoch % 30 == 0:
            save_weights(encoder, args.model_path + "encoder" + str(epoch))
            save_weights(decoder, args.model_path + "decoder" + str(epoch))

    # Final (un-numbered) checkpoint after the last epoch.
    if args.save_weight:
        save_weights(encoder, args.model_path + "encoder")
        save_weights(decoder, args.model_path + "decoder")

    if args.predict:
        sample = Sample(encoder=encoder, decoder=decoder, device=device)

        # Random (with replacement) sample indices for qualitative inspection.
        train_mask = [
            random.randint(0, train_captions.shape[0] - 1)
            for _ in range(args.numOfpredection)
        ]
        test_mask = [
            random.randint(0, test_captions.shape[0] - 1)
            for _ in range(args.numOfpredection)
        ]

        train_features = torch.from_numpy(train_feature[train_mask]).to(device)
        train_encoder_out = encoder(train_features)

        test_features = torch.from_numpy(test_feature[test_mask]).to(device)
        test_encoder_out = encoder(test_features)

        train_output = []
        test_output = []
        for i in range(args.numOfpredection):
            print(i)
            # NOTE(review): beam width is 2 for train samples but 50 for val
            # samples — this asymmetry looks like a debugging leftover; confirm
            # the intended beam sizes.
            train_output.append(sample.caption_image_beam_search(
                train_encoder_out[i].reshape(1, args.embed_size),
                data.word_to_idx, 2))
            test_output.append(sample.caption_image_beam_search(
                test_encoder_out[i].reshape(1, args.embed_size),
                data.word_to_idx, 50))

        print_output(output=test_output,
                     sample=0,
                     gt=test_captions[test_mask],
                     img=test_url[test_mask],
                     title="val",
                     show_image=args.show_image,
                     idx_to_word=data.idx_to_word)

        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        print("")

        # Typo fix: title was previously "traning".
        print_output(output=train_output,
                     sample=0,
                     gt=train_captions[train_mask],
                     img=train_url[train_mask],
                     title="training",
                     show_image=args.show_image,
                     idx_to_word=data.idx_to_word)