Code Example #1
import torch
import torch.nn as nn

# EncoderRNN, AttnDecoderRNN, prepareData and trainIters are assumed to be
# defined elsewhere in the project this snippet was taken from.


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_lang, output_lang, pairs = prepareData('eng',
                                                 'fra',
                                                 True,
                                                 dir='data',
                                                 filter=False)
    hidden_size = 512
    batch_size = 64
    iters = 50000
    encoder = EncoderRNN(input_lang.n_words, hidden_size)
    attn_decoder = AttnDecoderRNN(hidden_size,
                                  output_lang.n_words,
                                  dropout_p=0.1)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        encoder = nn.DataParallel(encoder)
        attn_decoder = nn.DataParallel(attn_decoder)
    encoder = encoder.to(device)
    attn_decoder = attn_decoder.to(device)

    trainIters(device,
               pairs,
               input_lang,
               output_lang,
               encoder,
               attn_decoder,
               batch_size,
               iters,
               print_every=250)
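
Note: once the models are wrapped in nn.DataParallel, the underlying modules
live under the .module attribute. A minimal sketch of a checkpoint helper
(hypothetical, not part of this example) that unwraps before saving:

def save_unwrapped(model, path):
    # nn.DataParallel stores the wrapped network under .module; saving the
    # inner state_dict keeps the checkpoint loadable on a single device.
    inner = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(inner.state_dict(), path)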
Code Example #2
import pickle

import torch
import torch.nn as nn
from torchvision import transforms

# EncoderCNN, AttnDecoderRNN, get_loader, train, validate,
# adjust_learning_rate and save_checkpoint come from the surrounding project;
# `device` and `best_bleu4` are assumed to be module-level globals.


def main(args):
    global best_bleu4, checkpoint

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    if args.checkpoint is None:
        decoder = AttnDecoderRNN(attention_dim=args.attention_dim,
                                 embed_dim=args.embed_dim,
                                 decoder_dim=args.decoder_dim,
                                 vocab_size=len(vocab),
                                 dropout=args.dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, decoder.parameters()),
                                             lr=args.decoder_lr)
        encoder = EncoderCNN()
        encoder.fine_tune(args.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=args.encoder_lr) if args.fine_tune_encoder else None
    else:
        checkpoint = torch.load(args.checkpoint)
        args.start_epoch = checkpoint['epoch'] + 1
        args.epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if args.fine_tune_encoder and encoder_optimizer is None:
            encoder.fine_tune(args.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(
                lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=args.encoder_lr)
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
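        # ImageNet channel means/stds expected by torchvision's pretrained ResNet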
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Build data loader
    train_loader = get_loader(args.image_dir,
                              args.caption_path,
                              vocab,
                              transform,
                              args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)

    val_loader = get_loader(args.image_dir_val,
                            args.caption_path_val,
                            vocab,
                            transform,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    for epoch in range(args.start_epoch, args.epochs):
        if args.epochs_since_improvement == 20:
            break
        if args.epochs_since_improvement > 0 and args.epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if args.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            args.epochs_since_improvement += 1
            print("\nEpoch since last improvement: %d\n" %
                  (args.epochs_since_improvement, ))
        else:
            args.epochs_since_improvement = 0

        save_checkpoint(args.data_name, epoch, args.epochs_since_improvement,
                        encoder, decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
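
Note: adjust_learning_rate is not defined in this example. A minimal sketch,
assuming it simply scales every parameter group's learning rate by the given
shrink factor (0.8 in the calls above):

def adjust_learning_rate(optimizer, shrink_factor):
    # Decay the learning rate of every parameter group by shrink_factor.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor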
Code Example #3
import numpy as np
import torch
from torch import optim
from torch.utils import data
from tqdm import tqdm

# EncoderRNN, AttnDecoderRNN, CustomLoss, Dataset and `device` are assumed
# to be defined elsewhere in the project this snippet was taken from.


class Seq2Pose:
    def __init__(self, wm, input_length, batch_size, hidden_size,
                 bidirectional, embedding_size, n_parameter, m_parameter,
                 learning_rate, clip, alpha, beta, pre_trained_file=None):
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.bidirectional = bidirectional
        self.n_parameter = n_parameter
        self.m_parameter = m_parameter
        self.learning_rate = learning_rate
        self.wm = wm
        self.clip = clip
        self.alpha = alpha
        self.beta = beta
        if pre_trained_file is None:
            self.encoder = EncoderRNN(self.wm, self.embedding_size,
                                      hidden_size, bidirectional)
            self.decoder = AttnDecoderRNN(self.hidden_size, 10)
            self.enc_optimizer = optim.Adam(self.encoder.parameters(),
                                            lr=self.learning_rate)
            self.dec_optimizer = optim.Adam(self.decoder.parameters(),
                                            lr=self.learning_rate)
            self.start = 0
        else:
            self.resume_training = True
            (self.encoder, self.decoder, self.enc_optimizer,
             self.dec_optimizer, self.start) = self.load_model_state(
                 pre_trained_file)
        self.decoder = self.decoder.to(device)
        self.encoder = self.encoder.to(device)

    def load_model_state(self, model_file):
        print("Resuming training from a given model...")
        model = torch.load(model_file,
                           map_location=lambda storage, loc: storage)
        epoch = model['epoch']
        encoder = EncoderRNN(self.wm, self.embedding_size,
                             self.hidden_size, self.bidirectional)
        decoder = AttnDecoderRNN(self.hidden_size, 10)
        enc_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        dec_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate)
        # Restore the saved weights and optimizer state; without these calls
        # the freshly built modules would start from random initialization.
        encoder.load_state_dict(model['encoder_state_dict'])
        decoder.load_state_dict(model['decoder_state_dict'])
        enc_optimizer.load_state_dict(model['encoder_optimizer_state_dict'])
        dec_optimizer.load_state_dict(model['decoder_optimizer_state_dict'])
        return encoder, decoder, enc_optimizer, dec_optimizer, epoch

    def train(self, epochs, x_train, y_train):
        """
        Training loop, trains the network for the given parameters.

        Keyword arguments:
        epochs - number of epochs to train for (looping over the whole dataset)
        x_train - training data, contains a list of integer encoded strings
        y_train - training data, contains a list of pose sequences
        """
        criterion = CustomLoss(self.alpha, self.beta)
        training_set = Dataset(x_train, y_train)
        training_generator = data.DataLoader(training_set,
                                             batch_size=self.batch_size,
                                             shuffle=True,
                                             collate_fn=self.pad_and_sort_batch,
                                             num_workers=8, drop_last=True)
        # torch.autograd.Variable is deprecated; plain tensors suffice.
        decoder_fixed_previous = torch.zeros(
            self.n_parameter, self.batch_size, 10).to(device)
        decoder_fixed_input = torch.zeros(self.batch_size, 10).to(device)

        for epoch in range(self.start, epochs):
            total_loss = 0
            for mini_batches, max_target_length in tqdm(training_generator):
                # kick-start vectors: reset the "previous pose" decoder inputs
                self.enc_optimizer.zero_grad()
                self.dec_optimizer.zero_grad()
                loss = 0
                decoder_previous_inputs = decoder_fixed_previous
                for z in range(self.n_parameter):
                    decoder_previous_inputs[z] = decoder_fixed_input
                for i, (x, y, lengths) in enumerate(mini_batches):
                    x = x.to(device)
                    y = y.to(device)
                    decoder_m = y.size(0)  # time steps in this target chunk
                    encoder_outputs, encoder_hidden = self.encoder(x, None)
                    decoder_hidden = encoder_hidden[:self.decoder.n_layers]
                    decoder_output = None
                    for n_prev in range(self.n_parameter):
                        decoder_output, decoder_hidden, attn_weights = self.decoder(
                            decoder_previous_inputs[n_prev].float(),
                            decoder_hidden, encoder_outputs)
                    decoder_input = decoder_output.float()
                    decoder_previous_generated = torch.zeros(
                        decoder_m, self.batch_size, 10).to(device)
                    decoder_outputs_generated = torch.zeros(
                        decoder_m, self.batch_size, 10).to(device)
                    for fut_pose in range(decoder_m):
                        decoder_output, decoder_hidden, attn_weights = self.decoder(
                            decoder_input, decoder_hidden, encoder_outputs)
                        decoder_outputs_generated[fut_pose] = decoder_output
                        # teacher forcing: feed the ground-truth pose next
                        decoder_input = y[fut_pose].float()
                    decoder_previous_inputs = decoder_outputs_generated[:-10]
                    # shape: [max_length, batch, features]
                    # mask generated outputs wherever the target is zero padding
                    decoder_masked = torch.where(y == 0.0, y.float(),
                                                 decoder_outputs_generated.float())
                    decoder_previous_generated[1:] = decoder_masked[:-1]
                    loss += criterion(decoder_masked,
                                      decoder_previous_generated, y.float())

                # Accumulate the batch loss once; summing loss.item() inside
                # the chunk loop would repeatedly count running partial sums.
                total_loss += loss.item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.encoder.parameters(),
                                               self.clip)
                torch.nn.utils.clip_grad_norm_(self.decoder.parameters(),
                                               self.clip)
                self.enc_optimizer.step()
                self.dec_optimizer.step()

            if epoch % 10 == 0:
                self.save_model(self.encoder, self.decoder, self.enc_optimizer,
                                self.dec_optimizer, epoch,
                                "./models/seq2seq_{}_{}.tar".format(
                                    epoch, total_loss / len(x_train)),
                                total_loss)
            print("Epoch: {} Loss: {}".format(epoch, total_loss))

    def pad_and_sort_batch(self, DataLoaderBatch):
        """
        Pads and sorts the batches, provided as a collate function.

        Keyword arguments:
        DataLoaderBatch - Batch of data coming from dataloader class.
        """
        batch_size = len(DataLoaderBatch)
        batch_split = list(zip(*DataLoaderBatch))

        seqs, targs, lengths, target_lengths = batch_split

        # calculate the chunk sizes for the mini-batches
        max_length = max(lengths)  # longest sequence in X
        max_target_length = max(target_lengths)  # longest sequence in Y
        number_of_chunks = int(max_target_length / self.m_parameter)
        not_in_chunk = max_target_length % self.m_parameter
        words_per_chunk = int(max_length / number_of_chunks)
        not_in_words_per_chunk = max_length % words_per_chunk

        # first zero-pad everything to the longest sequence in the batch
        padded_seqs = np.zeros((batch_size, max_length))
        for i, l in enumerate(lengths):
            padded_seqs[i, 0:l] = seqs[i][0:l]
        new_targets = np.zeros((batch_size, max([len(s) for s in targs]), 10))
        for i, item in enumerate(targs):
            new_targets[i][:len(targs[i])] = targs[i]
        seq_lengths, perm_idx = torch.tensor(lengths).sort(descending=True)
        seq_lengths = list(seq_lengths)
        seq_tensor = padded_seqs[perm_idx]
        target_tensor = new_targets[perm_idx]
        # The full batch is sorted; now create the mini-batch chunks.
        # In these chunks time comes first: [time, batch, features].
        # We also add a list of lengths, which is needed for padding.
        mini_batches = []  # contains an x and y tensor per chunk
        seq_tensor = np.transpose(seq_tensor, (1, 0))
        target_tensor = np.transpose(target_tensor, (1, 0, 2))
        counter = 0
        for i in range(number_of_chunks):
            x = seq_tensor[i * words_per_chunk:(i + 1) * words_per_chunk]
            y = target_tensor[i * self.m_parameter:(i + 1) * self.m_parameter]
            counter = i * words_per_chunk  # start offset of this chunk in X
            x_mini_batch_lengths = []
            for j in range(batch_size):
                if counter < seq_lengths[j] < counter + words_per_chunk:
                    x_mini_batch_lengths.append(seq_lengths[j].item() - counter)
                elif seq_lengths[j] >= counter + words_per_chunk:
                    x_mini_batch_lengths.append(words_per_chunk)
                else:
                    x_mini_batch_lengths.append(0)
            mini_batches.append([
                torch.tensor(x).long(),
                torch.tensor(y), x_mini_batch_lengths
            ])
        if not_in_chunk != 0:
            x = seq_tensor[number_of_chunks * words_per_chunk:]
            y = target_tensor[number_of_chunks * self.m_parameter:]
            x_mini_batch_lengths = []
            counter = number_of_chunks * words_per_chunk
            for j in range(batch_size):
                if counter < seq_lengths[j] < counter + words_per_chunk:
                    x_mini_batch_lengths.append(seq_lengths[j].item() - counter)
                elif seq_lengths[j] >= counter + words_per_chunk:
                    x_mini_batch_lengths.append(words_per_chunk)
                else:
                    x_mini_batch_lengths.append(0)
            if len(x) > 0 and len(y) > 0:
                mini_batches.append([
                    torch.tensor(x).long(),
                    torch.tensor(y), x_mini_batch_lengths
                ])
        return mini_batches, max_target_length


    def save_model(self, encoder, decoder, enc_optimizer, dec_optimizer,
                   epoch, PATH, loss):
        torch.save(
            {
                'epoch': epoch,
                'encoder_state_dict': encoder.state_dict(),
                'encoder_optimizer_state_dict': enc_optimizer.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'decoder_optimizer_state_dict': dec_optimizer.state_dict(),
                'loss': loss,
            }, PATH)
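
Note: a hypothetical construction call, shown only to illustrate how the
constructor arguments fit together; the word model, training data and
hyper-parameter values below are placeholders, not from the original project:

model = Seq2Pose(wm=word_model, input_length=30, batch_size=64,
                 hidden_size=256, bidirectional=True, embedding_size=300,
                 n_parameter=10, m_parameter=20, learning_rate=1e-4,
                 clip=5.0, alpha=0.5, beta=0.5)
model.train(epochs=100, x_train=x_train, y_train=y_train)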