Example #1
def runTest(n_layers, hidden_size, reverse, modelFile, beam_size, input,
            corpus):

    voc, pairs, valid_pairs, test_pairs = loadPrepareData(corpus)

    print('Building encoder and decoder ...')
    '''# attribute embeddings
    attr_size = 64
    attr_num = 2
    with open(os.path.join(save_dir, 'user_item.pkl'), 'rb') as fp:
        user_dict, item_dict = pickle.load(fp)
    num_user = len(user_dict)
    num_item = len(item_dict)
    attr_embeddings = []
    attr_embeddings.append(nn.Embedding(num_user, attr_size))    
    attr_embeddings.append(nn.Embedding(num_item, attr_size)) 
    if USE_CUDA:
        for attr_embedding in attr_embeddings:
            attr_embedding = attr_embedding.cuda()
   
    encoder = AttributeEncoder(attr_size, attr_num, hidden_size, attr_embeddings, n_layers)
    '''
    embedding = nn.Embedding(voc.n_words, hidden_size,
                             padding_idx=0)  # word embedding
    encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers)
    attn_model = 'concat'
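    # note: attn_model is assigned here but not passed to this DecoderRNN, so it has no effect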
    decoder = DecoderRNN(embedding, hidden_size, voc.n_words, n_layers)

    checkpoint = torch.load(modelFile)
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])
    # Switch to eval mode; this affects only dropout and batch-norm layers.
    encoder.eval()
    decoder.eval()

    if USE_CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    evaluateRandomly(encoder, decoder, voc, pairs, reverse, beam_size, 2)
Example #2
        train_data_loader.batch_sampler.sampler = new_sampler

        # Obtain the batch.
        images, captions = next(iter(train_data_loader))

        # Move batch of images and captions to GPU if CUDA is available.
        images = images.to(device)
        captions = captions.to(device)

        # Zero the gradients.
        decoder.zero_grad()
        encoder.zero_grad()

        # Set the encoder and decoder to training mode.
        encoder.train()
        decoder.train()

        # Pass the inputs through the CNN-RNN model.
        features = encoder(images)
        outputs = decoder(features, captions)

        # Calculate the batch loss.
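        # CrossEntropyLoss expects 2-D logits (N, C) and 1-D targets, so both are flattened across batch and time.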
        loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))

        # Backward pass.
        loss.backward()

        # Update the parameters in the optimizer.
        optimizer.step()

        with torch.no_grad():
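            pass  # snippet truncated here in the original source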
Example #3
def main(args):

    #setup tensorboard
    if args.tensorboard:
        cc = CrayonClient(hostname="localhost")
        print(cc.get_experiment_names())
        #if args.name in cc.get_experiment_names():
        try:
            cc.remove_experiment(args.name)
        except Exception:
            print("experiment didn't exist")
        cc_server = cc.create_experiment(args.name)

    # Create model directory
    full_model_path = args.model_path + "/" + args.name
    if not os.path.exists(full_model_path):
        os.makedirs(full_model_path)
    with open(full_model_path + "/parameters.json", 'w') as f:
        f.write((json.dumps(vars(args))))

    # Image preprocessing
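    # (transforms.Scale below is the old torchvision name; newer releases call it transforms.Resize)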

    transform = transforms.Compose([
        transforms.Scale(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    mini_transform = transforms.Compose(
        [transforms.ToPILImage(),
         transforms.Scale(20),
         transforms.ToTensor()])

    # Load vocabulary wrapper.
    if args.vocab_path is not None:
        with open(args.vocab_path, 'rb') as f:
            vocab = pickle.load(f)
    else:
        print("building new vocab")
        vocab = build_vocab(args.image_dir, 1, None)
        with open((full_model_path + "/vocab.pkl"), 'wb') as f:
            pickle.dump(vocab, f)

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)
    code_data_set = ProcessingDataset(root=args.image_dir,
                                      vocab=vocab,
                                      transform=transform)
    train_ds, val_ds = validation_split(code_data_set)
    train_loader = torch.utils.data.DataLoader(train_ds, collate_fn=collate_fn)
    test_loader = torch.utils.data.DataLoader(val_ds, collate_fn=collate_fn)
    train_size = len(train_loader)
    test_size = len(test_loader)

    # Build the models
    encoder = EncoderCNN(args.embed_size, args.train_cnn)
    print(encoder)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers)
    print(decoder)
    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
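    # Only the decoder plus the encoder's final linear and batch-norm layers are updated; the CNN backbone is not in the optimizer.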
    #params = list(decoder.parameters()) #+ list(encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    start_time = time.time()
    add_log_entry(args.name, start_time, vars(args))

    # Train the Models
    total_step = len(data_loader)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(data_loader):
            decoder.train()
            encoder.train()
            # Set mini-batch dataset
            # note: volatile=True here would make the graph non-differentiable and break loss.backward()
            image_ts = to_var(images)
            captions = to_var(captions)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]
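            # pack_padded_sequence drops the padded timesteps, so the loss sees only real tokens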
            count = images.size()[0]

            # Forward, Backward and Optimize
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(image_ts)
            outputs = decoder(features, captions, lengths)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total = targets.size(0)
            max_index = outputs.max(dim=1)[1]
            #correct = (max_index == targets).sum()
            _, predicted = torch.max(outputs.data, 1)
            correct = predicted.eq(targets.data).cpu().sum()
            accuracy = 100. * correct / total

            if args.tensorboard:
                cc_server.add_scalar_value("train_loss", loss.data[0])
                cc_server.add_scalar_value("perplexity", np.exp(loss.data[0]))
                cc_server.add_scalar_value("accuracy", accuracy)

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, accuracy: %2.2f Perplexity: %5.4f'
                    % (epoch, args.num_epochs, i, total_step, loss.data[0],
                       accuracy, np.exp(loss.data[0])))

            # Save the models
            if (i + 1) % args.save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(full_model_path,
                                 'decoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(full_model_path,
                                 'encoder-%d-%d.pkl' % (epoch + 1, i + 1)))
            if False and i % int(train_size / 10) == 0:  # periodic validation, deliberately disabled
                encoder.eval()
                #decoder.eval()
                correct = 0
                for ti, (timages, tcaptions,
                         tlengths) in enumerate(test_loader):
                    timage_ts = to_var(timages, volatile=True)
                    tcaptions = to_var(tcaptions)
                    ttargets = pack_padded_sequence(tcaptions,
                                                    tlengths,
                                                    batch_first=True)[0]
                    tfeatures = encoder(timage_ts)
                    toutputs = decoder(tfeatures, tcaptions, tlengths)
                    print(ttargets)
                    print(toutputs)
                    print(ttargets.size())
                    print(toutputs.size())
                    #correct = (ttargets.eq(toutputs[0].long())).sum()

                accuracy = 100 * correct / test_size
                print('accuracy: %.4f' % (accuracy))
                if args.tensorboard:
                    cc_server.add_scalar_value("accuracy", accuracy)

    torch.save(
        decoder.state_dict(),
        os.path.join(full_model_path,
                     'decoder-%d-%d.pkl' % (epoch + 1, i + 1)))
    torch.save(
        encoder.state_dict(),
        os.path.join(full_model_path,
                     'encoder-%d-%d.pkl' % (epoch + 1, i + 1)))
    end_time = time.time()
    print("finished training, runtime: %d", [(end_time - start_time)])
Example #4
class ImageDescriptor():
    def __init__(self, args, encoder):
        assert args.mode in ('train', 'val', 'test')
        self.__args = args
        self.__mode = args.mode
        self.__attention_mechanism = args.attention
        self.__stats_manager = ImageDescriptorStatsManager()
        self.__validate_when_training = args.validate_when_training
        self.__history = []

        if not os.path.exists(args.model_dir):
            os.makedirs(args.model_dir)

        self.__config_path = os.path.join(
            args.model_dir, f'config-{args.encoder}{args.encoder_ver}.txt')

        # Device configuration
        self.__device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        # training set vocab
        with open(args.vocab_path, 'rb') as f:
            self.__vocab = pickle.load(f)

        # validation set vocab
        with open(args.vocab_path.replace('train', 'val'), 'rb') as f:
            self.__vocab_val = pickle.load(f)

        # coco dataset
        self.__coco_train = CocoDataset(
            args.image_dir, args.caption_path, self.__vocab, args.crop_size)
        self.__coco_val = CocoDataset(
            args.image_dir, args.caption_path.replace('train', 'val'), self.__vocab_val, args.crop_size)

        # data loader
        self.__train_loader = torch.utils.data.DataLoader(dataset=self.__coco_train,
                                                          batch_size=args.batch_size,
                                                          shuffle=True,
                                                          num_workers=args.num_workers,
                                                          collate_fn=collate_fn)
        self.__val_loader = torch.utils.data.DataLoader(dataset=self.__coco_val,
                                                        batch_size=args.batch_size,
                                                        shuffle=False,
                                                        num_workers=args.num_workers,
                                                        collate_fn=collate_fn)
        # Build the models
        self.__encoder = encoder.to(self.__device)
        self.__decoder = DecoderRNN(args.embed_size, args.hidden_size,
                                    len(self.__vocab), args.num_layers, attention_mechanism=self.__attention_mechanism).to(self.__device)

        # Loss and optimizer
        self.__criterion = nn.CrossEntropyLoss()
        self.__params = list(self.__decoder.parameters(
        )) + list(self.__encoder.linear.parameters()) + list(self.__encoder.bn.parameters())
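        # The optimizer covers the decoder and only the encoder's final linear and batch-norm layers.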
        self.__optimizer = torch.optim.Adam(
            self.__params, lr=args.learning_rate)

        # Load checkpoint and check compatibility
        if os.path.isfile(self.__config_path):
            with open(self.__config_path, 'r') as f:
                content = f.read()[:-1]  # drop the trailing newline added by print() in save()
            if content != repr(self):
                # save the error info
                with open('config.err', 'w') as f:
                    print(f'f.read():\n{content}', file=f)
                    print(f'repr(self):\n{repr(self)}', file=f)
                raise ValueError(
                    "Cannot create this experiment: "
                    "I found a checkpoint conflicting with the current setting.")
            self.load(file_name=args.checkpoint)
        else:
            self.save()

    def setting(self):
        '''
        Return the setting of the experiment.
        '''
        return {'Net': (self.__encoder, self.__decoder),
                'Optimizer': self.__optimizer,
                'BatchSize': self.__args.batch_size}

    @property
    def epoch(self):
        return len(self.__history)

    @property
    def history(self):
        return self.__history

    # @property
    # def mode(self):
    #     return self.__args.mode

    # @mode.setter
    # def mode(self, m):
    #     self.__args.mode = m

    def __repr__(self):
        '''
        Pretty printer showing the setting of the experiment. This is what
        is displayed when doing `print(experiment)`. This is also what is
        saved in the `config.txt` file.
        '''
        string = ''
        for key, val in self.setting().items():
            string += '{}({})\n'.format(key, val)
        return string

    def state_dict(self):
        '''
        Returns the current state of the model.
        '''
        return {'Net': (self.__encoder.state_dict(), self.__decoder.state_dict()),
                'Optimizer': self.__optimizer.state_dict(),
                'History': self.__history}

    def save(self):
        '''
        Saves the model on disk, i.e, create/update the last checkpoint.
        '''
        file_name = os.path.join(
            self.__args.model_dir, '{}{}-epoch-{}.ckpt'.format(self.__args.encoder, self.__args.encoder_ver, self.epoch))
        torch.save(self.state_dict(), file_name)
        with open(self.__config_path, 'w') as f:
            print(self, file=f)

        print(f'Save to {file_name}.')

    def load(self, file_name=None):
        '''
        Loads the model from the last checkpoint saved on disk.

        Args:
            file_name (str): path to the checkpoint file
        '''
        if not file_name:
            # find the latest .ckpt file
            try:
                file_name = max(
                    glob.iglob(os.path.join(self.__args.model_dir, '*.ckpt')), key=os.path.getctime)
                print(f'Load from {file_name}.')
            except ValueError:
                raise FileNotFoundError(
                    'No checkpoint file in the model directory.')
        else:
            file_name = os.path.join(self.__args.model_dir, file_name)
            print(f'Load from {file_name}.')

        try:
            checkpoint = torch.load(file_name, map_location=self.__device)
        except FileNotFoundError:
            raise FileNotFoundError(
                'Please check --checkpoint, the name of the file')

        self.load_state_dict(checkpoint)
        del checkpoint

    def load_state_dict(self, checkpoint):
        '''
        Loads the model from the input checkpoint.

        Args:
            checkpoint: an object saved with torch.save() from a file.
        '''
        self.__encoder.load_state_dict(checkpoint['Net'][0])
        self.__decoder.load_state_dict(checkpoint['Net'][1])
        self.__optimizer.load_state_dict(checkpoint['Optimizer'])
        self.__history = checkpoint['History']

        # The following loops are used to fix a bug that was
        # discussed here: https://github.com/pytorch/pytorch/issues/2830
        # (it is supposed to be fixed in recent PyTorch version)
        for state in self.__optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.__device)

    def train(self, plot_loss=None):
        '''
        Train the network using backpropagation based
        on the optimizer and the training set.

        Args:
            plot_loss (func, optional): if not None, should be a function taking a
                single argument being an experiment (meant to be `self`).
                Similar to a visitor pattern, this function is meant to inspect
                the current state of the experiment and display/plot/save
                statistics. For example, if the experiment is run from a
                Jupyter notebook, `plot_loss` can be used to display the evolution
                of the loss with `matplotlib`. If the experiment is run on a
                server without display, `plot_loss` can be used to show statistics
                on `stdout` or save statistics in a log file. (default: None)
        '''
        self.__encoder.train()
        self.__decoder.train()
        self.__stats_manager.init()
        total_step = len(self.__train_loader)
        start_epoch = self.epoch
        print("Start/Continue training from epoch {}".format(start_epoch))

        if plot_loss is not None:
            plot_loss(self)

        for epoch in range(start_epoch, self.__args.num_epochs):
            t_start = time.time()
            self.__stats_manager.init()
            for i, (images, captions, lengths) in enumerate(self.__train_loader):
                # Set mini-batch dataset
                if not self.__attention_mechanism:
                    images = images.to(self.__device)
                    captions = captions.to(self.__device)
                else:
                    with torch.no_grad():
                        images = images.to(self.__device)
                    captions = captions.to(self.__device)

                targets = pack_padded_sequence(
                    captions, lengths, batch_first=True)[0]

                # Forward, backward and optimize
                if not self.__attention_mechanism:
                    features = self.__encoder(images)
                    outputs = self.__decoder(features, captions, lengths)
                    self.__decoder.zero_grad()
                    self.__encoder.zero_grad()
                else:
                    self.__encoder.zero_grad()
                    self.__decoder.zero_grad()
                    features, cnn_features = self.__encoder(images)
                    outputs = self.__decoder(
                        features, captions, lengths, cnn_features=cnn_features)
                loss = self.__criterion(outputs, targets)

                loss.backward()
                self.__optimizer.step()
                with torch.no_grad():
                    self.__stats_manager.accumulate(
                        loss=loss.item(), perplexity=np.exp(loss.item()))

                # Print log info each iteration
                if i % self.__args.log_step == 0:
                    print('[Training] Epoch: {}/{} | Step: {}/{} | Loss: {:.4f} | Perplexity: {:5.4f}'
                          .format(epoch+1, self.__args.num_epochs, i, total_step, loss.item(), np.exp(loss.item())))

            if not self.__validate_when_training:
                self.__history.append(self.__stats_manager.summarize())
                print("Epoch {} | Time: {:.2f}s\nTraining Loss: {:.6f} | Training Perplexity: {:.6f}".format(
                    self.epoch, time.time() - t_start, self.__history[-1]['loss'], self.__history[-1]['perplexity']))
            else:
                self.__history.append(
                    (self.__stats_manager.summarize(), self.evaluate()))
                print("Epoch {} | Time: {:.2f}s\nTraining Loss: {:.6f} | Training Perplexity: {:.6f}\nEvaluation Loss: {:.6f} | Evaluation Perplexity: {:.6f}".format(
                    self.epoch, time.time() - t_start,
                    self.__history[-1][0]['loss'], self.__history[-1][0]['perplexity'],
                    self.__history[-1][1]['loss'], self.__history[-1][1]['perplexity']))

            # Save the model checkpoints
            self.save()

            if plot_loss is not None:
                plot_loss(self)

        print("Finish training for {} epochs".format(self.__args.num_epochs))

    def evaluate(self, print_info=False):
        '''
        Evaluates the experiment, i.e., forward propagates the validation set
        through the network and returns the statistics computed by the stats
        manager.

        Args:
            print_info (bool): print the results of loss and perplexity
        '''
        self.__stats_manager.init()
        self.__encoder.eval()
        self.__decoder.eval()
        total_step = len(self.__val_loader)
        with torch.no_grad():
            for i, (images, captions, lengths) in enumerate(self.__val_loader):
                images = images.to(self.__device)
                captions = captions.to(self.__device)
                targets = pack_padded_sequence(
                    captions, lengths, batch_first=True)[0]

                # Forward
                if not self.__attention_mechanism:
                    features = self.__encoder(images)
                    outputs = self.__decoder(features, captions, lengths)
                else:
                    features, cnn_features = self.__encoder(images)
                    outputs = self.__decoder(
                        features, captions, lengths, cnn_features=cnn_features)
                loss = self.__criterion(outputs, targets)
                self.__stats_manager.accumulate(
                    loss=loss.item(), perplexity=np.exp(loss.item()))
                if i % self.__args.log_step == 0:
                    print('[Validation] Step: {}/{} | Loss: {:.4f} | Perplexity: {:5.4f}'
                          .format(i, total_step, loss.item(), np.exp(loss.item())))

        summarize = self.__stats_manager.summarize()
        if print_info:
            print(
                f'[Validation] Average loss for this epoch is {summarize["loss"]:.6f}')
            print(
                f'[Validation] Average perplexity for this epoch is {summarize["perplexity"]:.6f}\n')
        self.__encoder.train()
        self.__decoder.train()
        return summarize

    def mode(self, mode=None):
        '''
        Get the current mode or change mode.

        Args:
            mode (str): 'train' or 'eval' mode
        '''
        if not mode:
            return self.__mode
        self.__mode = mode

    def __load_image(self, image):
        '''
        Preprocess a PIL image (resize and normalize) for evaluation.

        Args:
            image (PIL Image): image
        '''
        image = image.resize([224, 224], Image.LANCZOS)
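        # 224x224 and the ImageNet mean/std below match what torchvision's pretrained CNNs expect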

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])
        image = transform(image).unsqueeze(0)

        return image

    def test(self, image_path=None, plot=False):
        '''
        Evaluate the model by generating the caption for the
        corresponding image at `image_path`.

        Note: This function does not compute a BLEU score.

        Args:
            image_path (str): file path of the evaluation image
            plot (bool): plot or not
        '''
        self.__encoder.eval()
        self.__decoder.eval()

        with torch.no_grad():
            if not image_path:
                image_path = self.__args.image_path

            image = Image.open(image_path)

            # only process with RGB image
            if np.array(image).ndim == 3:
                img = self.__load_image(image).to(self.__device)

                # generate a caption
                if not self.__attention_mechanism:
                    feature = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature)
                    sampled_ids = sampled_ids[0].cpu().numpy()
                else:
                    feature, cnn_features = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature, cnn_features)
                    sampled_ids = sampled_ids.cpu().data.numpy()

                # Convert word_ids to words
                sampled_caption = []
                for word_id in sampled_ids:
                    word = self.__vocab.idx2word[word_id]
                    sampled_caption.append(word)
                    if word == '<end>':
                        break
                sentence = ' '.join(sampled_caption[1:-1])  # drop the <start> and <end> tokens

                # Print out the image and the generated caption
                print(sentence)

                if plot:
                    image = Image.open(image_path)
                    plt.imshow(np.asarray(image))
            else:
                print('Non-RGB images are not supported.')
        self.__encoder.train()
        self.__decoder.train()

    def coco_image(self, idx, ds='val'):
        '''
        Access the image_id (which is part of the file name)
        and the corresponding image caption at index `idx` of the COCO dataset.

        Note: For jupyter notebook

        Args:
            idx (int): index of COCO dataset

        Returns:
            (dict)
        '''
        assert ds in ('train', 'val')

        if ds == 'train':
            ann_id = self.__coco_train.ids[idx]
            return self.__coco_train.coco.anns[ann_id]
        else:
            ann_id = self.__coco_val.ids[idx]
            return self.__coco_val.coco.anns[ann_id]

    @property
    def len_of_train_set(self):
        '''
        Number of training samples.
        '''
        return len(self.__coco_train)

    @property
    def len_of_val_set(self):
        return len(self.__coco_val)

    def bleu_score(self, idx, ds='val', plot=False, show_caption=False):
        '''
        Evaluate the BLEU score for index `idx` in COCO dataset.

        Note: For jupyter notebook

        Args:
            idx (int): index
            ds (str): training or validation dataset
            plot (bool): plot the image or not

        Returns:
            score (float): bleu score
        '''
        assert ds in ('train', 'val')
        self.__encoder.eval()
        self.__decoder.eval()

        with torch.no_grad():
            try:
                if ds == 'train':
                    ann_id = self.__coco_train.ids[idx]
                    coco_ann = self.__coco_train.coco.anns[ann_id]
                else:
                    ann_id = self.__coco_val.ids[idx]
                    coco_ann = self.__coco_val.coco.anns[ann_id]
            except (IndexError, KeyError):
                raise IndexError('Invalid index')

            image_id = coco_ann['image_id']

            # COCO file names zero-pad the image id to six digits
            image_id = str(image_id).zfill(6)

            image_path = f'{self.__args.image_dir}/COCO_train2014_000000{image_id}.jpg'
            if ds == 'val':
                image_path = image_path.replace('train', 'val')

            coco_list = coco_ann['caption'].split()

            image = Image.open(image_path)

            if np.array(image).ndim == 3:
                img = self.__load_image(image).to(self.__device)

                # generate a caption
                if not self.__attention_mechanism:
                    feature = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature)
                    sampled_ids = sampled_ids[0].cpu().numpy()
                else:
                    feature, cnn_features = self.__encoder(img)
                    sampled_ids = self.__decoder.sample(feature, cnn_features)
                    sampled_ids = sampled_ids.cpu().data.numpy()

                # Convert word_ids to words
                sampled_caption = []
                for word_id in sampled_ids:
                    word = self.__vocab.idx2word[word_id]
                    sampled_caption.append(word)
                    if word == '<end>':
                        break

                # strip punctuations and spacing
                sampled_list = [c for c in sampled_caption[1:-1]
                                if c not in punctuation]

                # sentence_bleu expects a list of reference token lists
                score = sentence_bleu([coco_list], sampled_list,
                                      smoothing_function=SmoothingFunction().method4)

                if plot:
                    plt.figure()
                    image = Image.open(image_path)
                    plt.imshow(np.asarray(image))
                    plt.title(f'score: {score}')
                    plt.xlabel(f'file: {image_path}')

                # Print out the generated caption
                if show_caption:
                    print(f'Sampled caption:\n{sampled_list}')
                    print(f'COCO caption:\n{coco_list}')

            else:
                print('Non-RGB images are not supported.')
                return

        return score
Example #5
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.CenterCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
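    # Validation transform: same normalization but no random flip, so evaluation is deterministic.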
    transform_val = transforms.Compose([
        transforms.CenterCrop(args.crop_size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    data_loader = get_loader(args.image_dir,
                             args.caption_path,
                             vocab,
                             transform,
                             args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers)
    val_loader = get_loader(args.val_dir,
                            args.val_caption_path,
                            vocab,
                            transform_val,
                            args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)
    # Build the models
    encoder = EncoderCNN(args.embed_size).to(device)
    encoder.freeze_bottom()
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers).to(device)
    #     decoder = BahdanauAttnDecoderRNN(args.hidden_size, args.embed_size, len(vocab)).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step = len(data_loader)
    accs, b1s, b2s, b3s, b4s = [], [], [], [], []
    for epoch in range(args.num_epochs):
        decoder.train()
        encoder.train()
        losses = []
        for i, (images, captions, lengths) in enumerate(data_loader):

            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]
            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            losses.append(loss.item())
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch + 1, args.num_epochs, i, total_step,
                            loss.item(), np.exp(loss.item())))

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))


        # Evaluate on the validation set; these metrics feed the log file and the plot below.
        acc, b1, b2, b3, b4 = evaluate(val_loader, encoder, decoder, vocab)
        accs.append(acc)
        b1s.append(b1)
        b2s.append(b2)
        b3s.append(b3)
        b4s.append(b4)
        avg_loss = sum(losses) / total_step

        print('Epoch {} Average Training Loss: {:.4f}'.format(
            epoch + 1, avg_loss))

        with open('stem_freeze_freq1000.txt', 'a') as file:
            file.write("Epoch {} \n".format(epoch + 1))
            file.write('Average Accuracy: {} \n'.format(acc))
            file.write('Average Loss: {} \n'.format(avg_loss))
            file.write('Average BLEU gram1: {} \n'.format(b1))
            file.write('Average BLEU gram2: {} \n'.format(b2))
            file.write('Average BLEU gram3: {} \n'.format(b3))
            file.write('Average BLEU gram4: {} \n'.format(b4))
            file.write('\n')

    plt.title("Accuracy vs BLEU score")
    plt.plot(np.arange(1, args.num_epochs + 1), accs, label='accuracy')
    plt.plot(np.arange(1, args.num_epochs + 1), b1s, label='BLEU 1')
    plt.plot(np.arange(1, args.num_epochs + 1), b2s, label='BLEU 2')
    plt.plot(np.arange(1, args.num_epochs + 1), b3s, label='BLEU 3')
    plt.plot(np.arange(1, args.num_epochs + 1), b4s, label='BLEU 4')
    plt.xlabel("epochs")
    plt.xticks(np.arange(1, args.num_epochs + 1))
    plt.legend(loc='upper left')
    plt.savefig('accuracy_BLEU.png')
    plt.clf()
Example #6
def trainIters(corpus,
               reverse,
               n_epoch,
               learning_rate,
               batch_size,
               n_layers,
               hidden_size,
               print_every,
               loadFilename=None,
               attn_model='dot',
               decoder_learning_ratio=5.0):
    print(
        "corpus: {}, reverse={}, n_epoch={}, learning_rate={}, batch_size={}, n_layers={}, hidden_size={}, decoder_learning_ratio={}"
        .format(corpus, reverse, n_epoch, learning_rate, batch_size, n_layers,
                hidden_size, decoder_learning_ratio))

    voc, pairs, valid_pairs, test_pairs = loadPrepareData(corpus)
    print('load data...')

    path = "data/attr2seq"
    # training data
    corpus_name = corpus
    training_batches = None
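    # Cache the batchified training data on disk: load it if present, otherwise build and save it.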
    try:
        training_batches = torch.load(
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'training_batches'),
                                   batch_size)))
    except FileNotFoundError:
        print('Training pairs not found, generating ...')
        training_batches = batchify(pairs, batch_size, voc, reverse)
        print('Complete building training pairs ...')
        torch.save(
            training_batches,
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'training_batches'),
                                   batch_size)))

    # validation/test data
    eval_batch_size = 10
    try:
        val_batches = torch.load(
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'val_batches'),
                                   eval_batch_size)))
    except FileNotFoundError:
        print('Validation pairs not found, generating ...')
        val_batches = batchify(valid_pairs,
                               eval_batch_size,
                               voc,
                               reverse,
                               evaluation=True)
        print('Complete building validation pairs ...')
        torch.save(
            val_batches,
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'val_batches'),
                                   eval_batch_size)))

    try:
        test_batches = torch.load(
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'test_batches'),
                                   eval_batch_size)))
    except FileNotFoundError:
        print('Test pairs not found, generating ...')
        test_batches = batchify(test_pairs,
                                eval_batch_size,
                                voc,
                                reverse,
                                evaluation=True)
        print('Complete building test pairs ...')
        torch.save(
            test_batches,
            os.path.join(
                save_dir, path,
                '{}_{}.tar'.format(filename(reverse, 'test_batches'),
                                   eval_batch_size)))

    # model
    checkpoint = None
    print('Building encoder and decoder ...')
    embedding = nn.Embedding(voc.n_words, hidden_size)
    encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers)
    attn_model = 'dot'  # note: not used by this DecoderRNN
    decoder = DecoderRNN(embedding, hidden_size, voc.n_words, n_layers)
    if loadFilename:
        checkpoint = torch.load(loadFilename)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])
    # use cuda
    if USE_CUDA:
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    # optimizer
    print('Building optimizers ...')
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=learning_rate * decoder_learning_ratio)
    if loadFilename:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # initialize
    print('Initializing ...')
    start_epoch = 0
    perplexity = []
    best_val_loss = None
    print_loss = 0
    if loadFilename:
        start_epoch = checkpoint['epoch'] + 1
        perplexity = checkpoint['plt']

    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()
        # train epoch
        encoder.train()
        decoder.train()
        print_loss = 0
        start_time = time.time()
        for batch, training_batch in enumerate(training_batches):
            input_variable_attr, input_variable, lengths, target_variable, mask, max_target_len = training_batch

            loss = train(input_variable, lengths, target_variable, mask,
                         max_target_len, encoder, decoder, embedding,
                         encoder_optimizer, decoder_optimizer, batch_size)
            print_loss += loss
            perplexity.append(loss)
            #print("batch{} loss={}".format(batch, loss))
            if batch % print_every == 0 and batch > 0:
                cur_loss = print_loss / print_every
                elapsed = time.time() - start_time

                print(
                    '| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                        epoch, batch, len(training_batches), learning_rate,
                        elapsed * 1000 / print_every, cur_loss,
                        math.exp(cur_loss)))

                print_loss = 0
                start_time = time.time()
        # evaluate
        val_loss = 0
        for val_batch in val_batches:
            input_variable_attr, input_variable, lengths, target_variable, mask, max_target_len = val_batch
            loss = evaluate(input_variable, lengths, target_variable, mask,
                            max_target_len, encoder, decoder, embedding,
                            encoder_optimizer, decoder_optimizer,
                            eval_batch_size)
            val_loss += loss
        val_loss /= len(val_batches)

        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f}'.format(epoch,
                                         (time.time() - epoch_start_time),
                                         val_loss, math.exp(val_loss)))
        print('-' * 89)
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            directory = os.path.join(save_dir, 'model',
                                     '{}_{}'.format(n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(
                {
                    'epoch': epoch,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'plt': perplexity
                },
                os.path.join(
                    directory,
                    '{}_{}.tar'.format(epoch,
                                       filename(reverse,
                                                'text_decoder_model'))))
            best_val_loss = val_loss

            # Run on test data.
            test_loss = 0
            for test_batch in test_batches:
                input_variable_attr, input_variable, lengths, target_variable, mask, max_target_len = test_batch
                loss = evaluate(input_variable, lengths, target_variable, mask,
                                max_target_len, encoder, decoder, embedding,
                                encoder_optimizer, decoder_optimizer,
                                eval_batch_size)
                test_loss += loss
            test_loss /= len(test_batches)
            print('-' * 89)
            print('| test loss {:5.2f} | test ppl {:8.2f}'.format(
                test_loss, math.exp(test_loss)))
            print('-' * 89)

        # Early stopping: halt as soon as the validation loss worsens.
        if val_loss > best_val_loss:
            break
Example #7
def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with open('data/multim_poem.json') as f, open('data/unim_poem.json') as unif:
        multim = json.load(f)
        unim = json.load(unif)

    multim = util.filter_multim(multim)
    # multim = multim[:128]
    with open('data/img_features.pkl', 'rb') as fi, open('data/poem_features.pkl', 'rb') as fp:
        img_features = pickle.load(fi)
        poem_features = pickle.load(fp)


    # make sure the vocab exists; word2idx/idx2word will be used in the embedder
    word2idx, idx2word = util.read_vocab_pickle(args.vocab_path)

    if args.source == 'unim':
        data = unim
        features = poem_features
    elif args.source == 'multim':
        data = multim
        features = img_features
    else:
        print('Error: source must be unim or multim!')
        sys.exit(1)

    # create data loader. the data will be in decreasing order of length
    data_loader = get_poem_poem_dataset(args.batch_size, shuffle=True,
                                        num_workers=args.num_workers, json_obj=data, features=features,
                                        max_seq_len=128, word2idx=word2idx, tokenizer=None)

    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(word2idx), device)
    decoder = DataParallel(decoder)
    if args.restore:
        decoder.load_state_dict(torch.load(args.ckpt))
    if args.load:
        decoder.load_state_dict(torch.load(args.load))
    decoder.to(device)

    discriminator = Discriminator(args.embed_size, args.hidden_size, len(word2idx), num_labels=2)
    # Tie the discriminator's embedding weights to the generator's embedding.
    discriminator.embed.weight = decoder.module.embed.weight
    discriminator = DataParallel(discriminator)
    if args.restore:
        discriminator.load_state_dict(torch.load(args.disc))
    discriminator.to(device)

    # optimization config
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(decoder.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 10], gamma=0.33)
    optimizerD = torch.optim.Adam(discriminator.parameters(), lr=args.learning_rate)

    sys.stderr.write('Start training...\n')
    total_step = len(data_loader)
    decoder.train()
    global_step = 0
    running_ls = 0
    for epoch in range(args.num_epochs):
        scheduler.step()  # note: PyTorch >= 1.1 expects this after the epoch's optimizer steps
        acc_ls = 0
        start = time.time()

        for i, batch in enumerate(data_loader):
            poem_embed, ids, lengths = [t.to(device) for t in batch]
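            # next-token targets: the input ids shifted left by one step, packed so padding is skipped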
            targets = pack_padded_sequence(ids[:, 1:], lengths, batch_first=True)[0]
            # train discriminator

            # train with real
            discriminator.zero_grad()
            pred_real = discriminator(ids[:, 1:], lengths)
            real_label = torch.ones(ids.size(0), dtype=torch.long).to(device)
            loss_d_real = criterion(pred_real, real_label)
            loss_d_real.backward(torch.ones_like(loss_d_real), retain_graph=True)

            # train with fake

            logits = decoder(poem_embed, ids, lengths)
            weights = F.softmax(logits, dim=-1)
            m = Categorical(probs=weights)
            generated_ids = m.sample()

            # generated_ids = torch.argmax(logits, dim=-1)
            pred_fake = discriminator(generated_ids.detach(), lengths)
            fake_label = torch.zeros(ids.size(0)).long().to(device)
            loss_d_fake = criterion(pred_fake, fake_label)
            loss_d_fake.backward(torch.ones_like(loss_d_fake), retain_graph=True)

            loss_d = loss_d_real.mean().item() + loss_d_fake.mean().item()

            optimizerD.step()

            # train generator
            decoder.zero_grad()
            # Policy-gradient (REINFORCE-style) update: weight the negative log-probability
            # of the sampled ids by the discriminator's estimated probability that they are real.
            reward = F.softmax(pred_fake, dim=-1)[:, 1].unsqueeze(-1)
            loss_r = -m.log_prob(generated_ids) * reward
            loss_r.backward(torch.ones_like(loss_r), retain_graph=True)
            loss_r = loss_r.mean().item()

            loss = criterion(pack_padded_sequence(logits, lengths, batch_first=True)[0], targets)
            loss.backward(torch.ones_like(loss))
            loss = loss.mean().item()
            # loss = loss_r
            running_ls += loss
            acc_ls += loss

            # Clip each parameter tensor's gradient norm individually.
            for param in decoder.parameters():
                torch.nn.utils.clip_grad_norm_(param, 0.25)

            optimizer.step()
            global_step += 1

            if global_step % args.log_step == 0:
                elapsed_time = time.time() - start
                iters_per_sec = (i + 1) / elapsed_time
                remaining = (total_step - i - 1) / iters_per_sec
                remaining_fmt = time.strftime("%H:%M:%S", time.gmtime(remaining))
                elapsed_fmt = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

                print('[{}/{}, {}/{}], ls_d:{:.2f}, ls_r:{:.2f} ls: {:.2f}, Acc: {:.2f} Perp: {:5.2f} {:.3}it/s {}<{}'
                      .format(epoch+1, args.num_epochs, i+1, total_step, loss_d, loss_r,
                              running_ls / args.log_step, acc_ls / (i+1), np.exp(acc_ls / (i+1)),
                              iters_per_sec, elapsed_fmt, remaining_fmt ) )
                running_ls = 0

            if global_step % args.save_step == 0:
                torch.save(decoder.state_dict(), args.ckpt)
                torch.save(discriminator.state_dict(), args.disc)
    torch.save(decoder.state_dict(), args.save)
    torch.save(discriminator.state_dict(), args.disc)