import argparse
import os
import numpy as np
import pickle
from copy import deepcopy

from data_loader import DataLoader
from representations import *
from algorithms import *

if os.path.exists("train.pkl"):
    with open("train.pkl", 'rb') as f:
        train_corpus = pickle.load(f)
else:
    train_base_dir = "./data/aclImdb/train/"
    loader = DataLoader(train_base_dir)
    pos, neg = loader.gen_data()
    train_corpus = pos + neg
    with open('train.pkl', 'wb') as f:
        pickle.dump(train_corpus, f)

if os.path.exists("test.pkl"):
    with open("test.pkl", 'rb') as f:
        test_corpus = pickle.load(f)
else:
    test_base_dir = "./data/aclImdb/test/"
    loader = DataLoader(test_base_dir)
    pos, neg = loader.gen_data()
    test_corpus = pos + neg
    with open('test.pkl', 'wb') as f:
        pickle.dump(test_corpus, f)
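
For reference, the two cache-or-rebuild branches above follow the same pattern, so they could be folded into one helper. A minimal sketch, assuming DataLoader.gen_data() keeps returning a (pos, neg) pair of document lists:

def load_corpus(cache_path, base_dir):
    # Reuse the pickled corpus when present; otherwise rebuild it and cache it.
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)
    pos, neg = DataLoader(base_dir).gen_data()
    corpus = pos + neg
    with open(cache_path, 'wb') as f:
        pickle.dump(corpus, f)
    return corpus

# e.g. train_corpus = load_corpus('train.pkl', './data/aclImdb/train/')
#      test_corpus = load_corpus('test.pkl', './data/aclImdb/test/')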
Example #2
# Third-party imports used below; project-specific helpers (load_captions,
# Vocabulary, DataLoader, get_loader, EncoderCNN, DecoderRNN) and the globals
# train_dir and device are assumed to be defined elsewhere in the repository.
import os

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms


def main(args):
    # Build the vocabulary from the training captions
    threshold = 20
    captions_dict = load_captions(train_dir)
    vocab = Vocabulary(captions_dict, threshold)
    vocab_size = vocab.index

    # Create the model directory if it does not exist
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing; note that (0.5, 0.5, 0.5) normalization is used here
    # rather than the ImageNet statistics usually paired with a pretrained ResNet
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load images and captions for training
    dataloader = DataLoader(train_dir, vocab, transform)
    imagenumbers, captiontotal, imagetotal = dataloader.gen_data()

    # Build data loader
    data_loader = get_loader(imagenumbers,
                             captiontotal,
                             imagetotal,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # Build the models
    encoder = EncoderCNN(args.embed_size).to(device)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, vocab_size,
                         args.num_layers).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):

            # Move the mini-batch to the device; targets are the flattened,
            # unpadded caption tokens produced by pack_padded_sequence
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, loss.item(),
                            np.exp(loss.item())))

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))
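
For reference, a minimal sketch of the command-line interface this training entry point appears to expect, reconstructed from the args.* attributes referenced above; the default values are illustrative assumptions, not the project's actual settings.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Paths and logging intervals (defaults below are assumptions)
    parser.add_argument('--model_path', type=str, default='models/')
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--save_step', type=int, default=1000)
    # Model hyper-parameters
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--num_layers', type=int, default=1)
    # Training hyper-parameters
    parser.add_argument('--num_epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    main(parser.parse_args())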
Example #3
# Standard-library and third-party imports used below; the DataLoader and DER
# classes are assumed to come from this project's own modules.
import itertools
import os
import pickle
import time

import numpy as np
import pandas as pd
import tensorflow as tf  # written against the TF1 graph/session API


class Solver:
    def __init__(self, args, id):
        self.args = args
        self.data_loader = DataLoader(args)
        self.data_loader.gen_data('train')
        self.data_loader.gen_data('validation')
        self.data_loader.gen_data('test')
        self.data_loader.gen_data('train_validation')
        self.model_args = self.args

        self.model_args.user_number = self.data_loader.user_num
        self.model_args.item_number = self.data_loader.item_num
        self.model_args.word_number = self.data_loader.word_num
        self.model_args.max_interaction_length = self.data_loader.max_interaction_length
        self.model_args.max_sentence_length = self.data_loader.max_sentence_length
        self.model_args.max_sentence_word_length = self.data_loader.max_sentence_word_length
        self.model_args.time_bin_number = self.data_loader.time_bin_number
        self.model_args.global_rating = self.data_loader.global_rating
        self.experiment_id = id

        self.model = DER(self.model_args)
        self.model.build_loss()
        print(self.model_args)
        # Get loss, evaluation, operation, rating and attention prediction.
        self.loss = self.model.get_loss()
        self.evaluation_mse_sum = self.model.get_mse_sum()
        self.op = self.model.get_train_op()
        self.rating_prediction = self.model.get_rating_prediction()
        self.attention_prediction = self.model.get_attention_weight()
        self.item_reviews_dict = self.data_loader.item_reviews_dict
        self.item_real_reviews_dict = self.data_loader.item_real_reviews_dict

    def save_parameters(self):
        path = os.path.join(self.args.output_path, str(self.experiment_id))
        if not os.path.exists(path):
            os.makedirs(path)
        self.args.logger.info('saving parameters ...')

        arg_dict = dict()
        for k, v in self.model_args.__dict__.items():
            if k != 'logger':
                arg_dict[k] = v

        with open(os.path.join(path, 'model_args'), 'wb') as f:
            pickle.dump(arg_dict, f)
        t = pd.DataFrame(self.train_rmse_vs_epoch)
        t.to_csv(os.path.join(path, 'train_rmse_vs_epoch'),
                 index=False,
                 header=None)
        t = pd.DataFrame(self.validation_rmse_vs_epoch)
        t.to_csv(os.path.join(path, 'validation_rmse_vs_epoch'),
                 index=False,
                 header=None)

    def save_attention(self):
        self.args.logger.info('saving attention ...')
        self.step = 0
        self.result = []
        max_step = self.data_loader.test_records_num // self.args.batch_size
        while self.step <= max_step:
            if (self.step + 1
                ) * self.args.batch_size > self.data_loader.test_records_num:
                b = self.data_loader.test_records_num - self.step * self.args.batch_size
            else:
                b = self.args.batch_size

            start = self.step * self.args.batch_size
            end = start + b
            batch_users, batch_previous_items, batch_previous_times, batch_previous_reviews, \
            batch_previous_ratings, batch_previous_lengths, batch_current_items, batch_current_ratings, \
            batch_current_input_reviews, batch_current_input_reviews_users, batch_current_output_review, batch_current_input_reviews_length \
                = self.data_loader.gen_batch_data(start, end, 'test')
            self.step += 1

            att, a, b, c = self.sess.run(
                [
                    self.model.item_review_attention, self.model.out_att,
                    self.model.review, self.model.user_att
                ],
                feed_dict={
                    self.model.user_plh:
                    batch_users,
                    self.model.previous_items_plh:
                    batch_previous_items,
                    self.model.previous_times_plh:
                    batch_previous_times,
                    self.model.previous_reviews_plh:
                    batch_previous_reviews,
                    self.model.previous_ratings_plh:
                    batch_previous_ratings,
                    self.model.previous_lengths_plh:
                    batch_previous_lengths,
                    self.model.current_item_plh:
                    batch_current_items,
                    self.model.current_rating_plh:
                    batch_current_ratings,
                    self.model.current_input_reviews_plh:
                    batch_current_input_reviews,
                    self.model.current_input_reviews_users_plh:
                    batch_current_input_reviews_users,
                    self.model.current_input_reviews_length_plh:
                    batch_current_input_reviews_length
                })
            # Debugging output (left disabled):
            # print(np.array(att).shape)
            # print(att[0])
            # print(att[1])
            # print(np.array(a).shape)
            # print(a[0])
            # print(np.array(b).shape)
            # print(b[0])
            # print(np.array(c).shape)
            # print(c)
            # input()

            for i in range(len(att)):
                user = batch_users[i]
                item = batch_current_items[i]
                attention = '@@@'.join([str(k[0]) for k in att[i]])
                reviews = [
                    j[1] for j in self.item_real_reviews_dict[str(
                        batch_current_items[i])] if j[0] != str(batch_users[i])
                ]
                reviews = list(itertools.chain.from_iterable(
                    reviews))[:self.model_args.max_sentence_length]
                reviews = '@@@'.join(reviews)
                line = str(user) + '||' + str(
                    item) + '||' + attention + '||' + reviews
                self.result.append(line)

        path = os.path.join(self.args.output_path, str(self.experiment_id))
        if not os.path.exists(path):
            os.makedirs(path)
        t = pd.DataFrame(self.result)
        t.to_csv(os.path.join(path, 'attention_results'),
                 index=False,
                 header=None)

    def run(self):
        self.args.logger.info('running ...')
        conf = tf.ConfigProto(allow_soft_placement=True)
        conf.gpu_options.allow_growth = True
        self.min_rmse = 1000000
        self.test_min_rmse = 1000000
        self.train_rmse_vs_epoch = []
        self.validation_rmse_vs_epoch = []
        self.test_rmse_vs_epoch = []

        with tf.Session(config=conf) as self.sess:
            self.model.model_init(self.sess)
            for iter in range(self.args.epoch):
                self.args.logger.info('*************************************')
                self.args.logger.info('epoch: ' + str(iter) + ' begin.')
                s = time.time()
                self.data_loader.shuffle()
                t_loss = 0.0
                self.step = 0
                if self.model_args.mode == 'validation':
                    max_step = self.data_loader.train_records_num // self.args.batch_size
                    record_number = self.data_loader.train_records_num
                else:
                    max_step = self.data_loader.train_validation_records_num // self.args.batch_size
                    record_number = self.data_loader.train_validation_records_num
                while self.step <= max_step:
                    if (self.step + 1) * self.args.batch_size > record_number:
                        b = record_number - self.step * self.args.batch_size
                    else:
                        b = self.args.batch_size

                    start = self.step * self.args.batch_size
                    end = start + b
                    if end > start:
                        if self.model_args.mode == 'validation':
                            batch_users, batch_previous_items, batch_previous_times, batch_previous_reviews, \
                            batch_previous_ratings, batch_previous_lengths, batch_current_items, batch_current_ratings, \
                            batch_current_input_reviews, batch_current_input_reviews_users, batch_current_output_review, batch_current_input_reviews_length \
                                = self.data_loader.gen_batch_data(start, end, 'train')
                        else:
                            batch_users, batch_previous_items, batch_previous_times, batch_previous_reviews, \
                            batch_previous_ratings, batch_previous_lengths, batch_current_items, batch_current_ratings, \
                            batch_current_input_reviews, batch_current_input_reviews_users, batch_current_output_review, batch_current_input_reviews_length \
                                = self.data_loader.gen_batch_data(start, end, 'train_validation')

                        self.sess.run(
                            self.op,
                            feed_dict={
                                self.model.user_plh:
                                batch_users,
                                self.model.previous_items_plh:
                                batch_previous_items,
                                self.model.previous_times_plh:
                                batch_previous_times,
                                self.model.previous_reviews_plh:
                                batch_previous_reviews,
                                self.model.previous_ratings_plh:
                                batch_previous_ratings,
                                self.model.previous_lengths_plh:
                                batch_previous_lengths,
                                self.model.current_item_plh:
                                batch_current_items,
                                self.model.current_rating_plh:
                                batch_current_ratings,
                                self.model.current_input_reviews_plh:
                                batch_current_input_reviews,
                                self.model.current_input_reviews_users_plh:
                                batch_current_input_reviews_users,
                                self.model.current_input_reviews_length_plh:
                                batch_current_input_reviews_length
                            })
                        self.step += 1

                time_consuming = str(time.time() - s)
                self.args.logger.info('epoch: ' + str(iter) +
                                      ' end. time consuming: ' +
                                      time_consuming)

                self.args.logger.info('epoch: ' + str(iter) + ' eval begin.')
                s = time.time()
                mse_sum = 0.0
                self.step = 0
                max_step = self.data_loader.train_records_num // self.args.batch_size
                while self.step <= max_step:
                    if (
                            self.step + 1
                    ) * self.args.batch_size > self.data_loader.train_records_num:
                        b = self.data_loader.train_records_num - self.step * self.args.batch_size
                    else:
                        b = self.args.batch_size
                    start = self.step * self.args.batch_size
                    end = start + b
                    if end > start:
                        batch_users, batch_previous_items, batch_previous_times, batch_previous_reviews, \
                        batch_previous_ratings, batch_previous_lengths, batch_current_items, batch_current_ratings, \
                        batch_current_input_reviews, batch_current_input_reviews_users, batch_current_output_review, batch_current_input_reviews_length \
                            = self.data_loader.gen_batch_data(start, end, 'train')
                        loss = self.sess.run(
                            self.evaluation_mse_sum,
                            feed_dict={
                                self.model.user_plh:
                                batch_users,
                                self.model.previous_items_plh:
                                batch_previous_items,
                                self.model.previous_times_plh:
                                batch_previous_times,
                                self.model.previous_reviews_plh:
                                batch_previous_reviews,
                                self.model.previous_ratings_plh:
                                batch_previous_ratings,
                                self.model.previous_lengths_plh:
                                batch_previous_lengths,
                                self.model.current_item_plh:
                                batch_current_items,
                                self.model.current_rating_plh:
                                batch_current_ratings,
                                self.model.current_input_reviews_plh:
                                batch_current_input_reviews,
                                self.model.current_input_reviews_users_plh:
                                batch_current_input_reviews_users,
                                self.model.current_input_reviews_length_plh:
                                batch_current_input_reviews_length
                            })
                        mse_sum += np.array(loss).sum()
                        self.step += 1

                rmse = np.sqrt(mse_sum / self.data_loader.train_records_num)
                self.train_rmse_vs_epoch.append(rmse)
                self.args.logger.info('epoch: ' + str(iter) +
                                      ' training loss: ' + str(rmse))

                mse_sum = 0.0
                self.step = 0
                max_step = self.data_loader.validation_records_num // self.args.batch_size
                while self.step <= max_step:
                    if (
                            self.step + 1
                    ) * self.args.batch_size > self.data_loader.validation_records_num:
                        b = self.data_loader.validation_records_num - self.step * self.args.batch_size
                    else:
                        b = self.args.batch_size

                    start = self.step * self.args.batch_size
                    end = start + b
                    if end > start:
                        batch_users, batch_previous_items, batch_previous_times, batch_previous_reviews, \
                        batch_previous_ratings, batch_previous_lengths, batch_current_items, batch_current_ratings, \
                        batch_current_input_reviews, batch_current_input_reviews_users, batch_current_output_review, batch_current_input_reviews_length \
                            = self.data_loader.gen_batch_data(start, end, 'validation')
                        error = self.sess.run(
                            self.evaluation_mse_sum,
                            feed_dict={
                                self.model.user_plh:
                                batch_users,
                                self.model.previous_items_plh:
                                batch_previous_items,
                                self.model.previous_times_plh:
                                batch_previous_times,
                                self.model.previous_reviews_plh:
                                batch_previous_reviews,
                                self.model.previous_ratings_plh:
                                batch_previous_ratings,
                                self.model.previous_lengths_plh:
                                batch_previous_lengths,
                                self.model.current_item_plh:
                                batch_current_items,
                                self.model.current_rating_plh:
                                batch_current_ratings,
                                self.model.current_input_reviews_plh:
                                batch_current_input_reviews,
                                self.model.current_input_reviews_users_plh:
                                batch_current_input_reviews_users,
                                self.model.current_input_reviews_length_plh:
                                batch_current_input_reviews_length
                            })
                        mse_sum += np.array(error).sum()
                        self.step += 1

                rmse = np.sqrt(mse_sum /
                               self.data_loader.validation_records_num)
                self.validation_rmse_vs_epoch.append(rmse)
                if rmse < self.min_rmse:
                    self.min_rmse = rmse
                    #self.save_attention()
                self.args.logger.info('epoch: ' + str(iter) +
                                      ' validation rmse: ' + str(rmse))
                self.args.logger.info('current best rmse: ' +
                                      str(self.min_rmse))

                mse_sum = 0.0
                self.step = 0
                max_step = self.data_loader.test_records_num // self.args.batch_size
                while self.step <= max_step:
                    if (
                            self.step + 1
                    ) * self.args.batch_size > self.data_loader.test_records_num:
                        b = self.data_loader.test_records_num - self.step * self.args.batch_size
                    else:
                        b = self.args.batch_size

                    start = self.step * self.args.batch_size
                    end = start + b
                    if end > start:
                        batch_users, batch_previous_items, batch_previous_times, batch_previous_reviews, \
                        batch_previous_ratings, batch_previous_lengths, batch_current_items, batch_current_ratings, \
                        batch_current_input_reviews, batch_current_input_reviews_users, batch_current_output_review, batch_current_input_reviews_length \
                            = self.data_loader.gen_batch_data(start, end, 'test')
                        error = self.sess.run(
                            self.evaluation_mse_sum,
                            feed_dict={
                                self.model.user_plh:
                                batch_users,
                                self.model.previous_items_plh:
                                batch_previous_items,
                                self.model.previous_times_plh:
                                batch_previous_times,
                                self.model.previous_reviews_plh:
                                batch_previous_reviews,
                                self.model.previous_ratings_plh:
                                batch_previous_ratings,
                                self.model.previous_lengths_plh:
                                batch_previous_lengths,
                                self.model.current_item_plh:
                                batch_current_items,
                                self.model.current_rating_plh:
                                batch_current_ratings,
                                self.model.current_input_reviews_plh:
                                batch_current_input_reviews,
                                self.model.current_input_reviews_users_plh:
                                batch_current_input_reviews_users,
                                self.model.current_input_reviews_length_plh:
                                batch_current_input_reviews_length
                            })
                        # print error
                        mse_sum += np.array(error).sum()
                        self.step += 1
                rmse = np.sqrt(mse_sum / self.data_loader.test_records_num)
                self.test_rmse_vs_epoch.append(rmse)
                if rmse < self.test_min_rmse:
                    self.test_min_rmse = rmse
                    self.save_attention()

                self.args.logger.info('epoch: ' + str(iter) + ' test rmse: ' +
                                      str(rmse))
                self.args.logger.info('current best test rmse: ' +
                                      str(self.test_min_rmse))
                time_consuming = str(time.time() - s)
                self.args.logger.info('epoch: ' + str(iter) +
                                      ' eval end. time consuming: ' +
                                      time_consuming)
                self.args.logger.info('*************************************')

            self.save_parameters()
            self.sess.close()
            if self.model_args.mode == 'validation':
                return self.min_rmse
            else:
                return self.test_min_rmse
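
A minimal usage sketch for the class above, assuming an args namespace that carries the fields Solver reads (batch_size, epoch, mode, output_path, a configured logger, and the paths DataLoader expects); build_args() below is a hypothetical helper, not part of the project:

if __name__ == '__main__':
    args = build_args()        # hypothetical: assembles the argument namespace
    solver = Solver(args, id=0)
    best_rmse = solver.run()   # validation rmse in 'validation' mode, test rmse otherwise
    args.logger.info('best rmse: ' + str(best_rmse))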
Example #4
# Third-party imports used below; project-specific helpers (load_captions,
# Vocabulary, DataLoader, get_loader, EncoderCNN, DecoderRNN) and the globals
# train_dir, val_dir, encoder_path, decoder_path and device are assumed to be
# defined elsewhere in the repository.
import numpy as np
import torch
from nltk.translate.bleu_score import sentence_bleu
from torchvision import transforms


def main(args):
    threshold = 20
    captions_dict = load_captions(train_dir)
    vocab = Vocabulary(captions_dict, threshold)
    vocab_size = vocab.index
    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataloader = DataLoader(val_dir, vocab, transform)
    imagenumbers, captiontotal, imagetotal = dataloader.gen_data()

    # Build data loader
    data_loader = get_loader(imagenumbers,
                             captiontotal,
                             imagetotal,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)

    # Build the models in evaluation mode
    encoder = EncoderCNN(args.embed_size).eval()
    decoder = DecoderRNN(args.embed_size, args.hidden_size, vocab_size,
                         args.num_layers).eval()
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # Load the trained model parameters
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))

    total_step = len(data_loader)

    # List to store the BLEU scores
    bleu_scores = []

    for i, (images, captions, lengths) in enumerate(data_loader):

        # Set mini-batch dataset
        images = images.to(device)
        # captions = captions.to(device)

        # Generate a caption from the image
        feature = encoder(images)
        sampled_ids = decoder.sample(feature)
        sampled_ids = sampled_ids[0].cpu().numpy()

        # Convert word_ids to words
        sampled_caption = []
        for word_id in sampled_ids:
            word = vocab.get_word(word_id)
            sampled_caption.append(word)
            if word == '<end>':
                break
        sentence = ' '.join(sampled_caption)

        # Tokenize the reference caption (word ids -> words); sentence_bleu
        # expects token lists for both the references and the hypothesis
        reference = [vocab.get_word(word_id) for word_id in captions[0].numpy()]
        score = sentence_bleu([reference], sampled_caption, args.bleu_weights)
        bleu_scores.append(score)

        # Print log info
        if i % args.log_step == 0:
            print('Finish [{}/{}], Current BLEU Score: {:.4f}'.format(
                i, total_step, np.mean(bleu_scores)))
            print(sentence)
            print(captions)

    # Save the per-sample scores and the overall mean (object array, since the
    # two entries have different shapes)
    np.save('test_results.npy',
            np.array([bleu_scores, np.mean(bleu_scores)], dtype=object))
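
One practical note on the BLEU computation above: when a sampled caption shares no higher-order n-grams with the reference, nltk's sentence_bleu returns a score of (or near) zero and emits a warning. A small sketch of how a smoothing function could be plugged in, using nltk's own SmoothingFunction; the wrapper name is illustrative:

from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

def smoothed_bleu(reference_tokens, hypothesis_tokens,
                  weights=(0.25, 0.25, 0.25, 0.25)):
    # method1 adds a small epsilon to zero n-gram counts instead of
    # collapsing the whole score to zero
    smoother = SmoothingFunction().method1
    return sentence_bleu([reference_tokens], hypothesis_tokens, weights,
                         smoothing_function=smoother)

# e.g. smoothed_bleu(['a', 'dog', 'runs', '<end>'], sampled_caption)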