Example #1
def evaluate():
    """
    Evaluate the model on the validation set.
    """
    val_loss.reset_states()
    for inp, tar in val_dataset:
        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]

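        # The look-ahead mask keeps each decoder position from attending to later positions.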
        mask = masks.create_look_ahead_mask(tf.shape(tar_inp)[1])

        predictions, _ = model(inp,
                               tar_inp,
                               training=False,
                               look_ahead_mask=mask,
                               dec_padding_mask=None)

        loss = loss_function(tar_real, predictions, val_loss_object)
        val_loss(loss)
        val_accuracy(tar_real, predictions)

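        # Note: val_bleu ends up holding only the BLEU of the last validation batch.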
        global val_bleu
        val_bleu = bleu_score(predictions=predictions, labels=tar_real)

    ckpt_manager.save()
Example #2
def train_step(inp, tar, evaluate_step):
    """
    A training step on a batch.
    """
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    mask = masks.create_look_ahead_mask(tf.shape(tar_inp)[1])

    with tf.GradientTape() as tape:
        predictions, _ = model(inp,
                               tar_inp,
                               training=True,
                               look_ahead_mask=mask,
                               dec_padding_mask=None)
        loss = loss_function(tar_real, predictions, train_loss_object)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    global train_bleu
    if evaluate_step:
        train_loss(loss)
        train_accuracy(tar_real, predictions)
        train_bleu = bleu_score(predictions=predictions, labels=tar_real)
Example #3
def score_helper(image_path):
    caption = generate_caption(image_path, vocab, encoder, decoder,
                               transform)
    generated_score, theoretical_score = bleu_score(
        image_path, caption, name_caption_frame)
    total_generated_score.append(generated_score)
    total_theoretical_score.append(theoretical_score)
    print(generated_score, theoretical_score)
Example #4
  def train_epochs(self, encoder, decoder, train_data, dev_data=None,
                   pred_=None):

    encoder.to(self.device)
    decoder.to(self.device)

    train_dataloader = DataLoader(train_data, self.batch_size, shuffle=True)

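    # Only optimize parameters that require gradients, so frozen layers are skipped.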
    encoder_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                   encoder.parameters()), lr=self.learning_rate)
    decoder_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                   decoder.parameters()), lr=self.learning_rate)
    checkpoint_epoch = 0
    total_epoch = 0
    train_score = [0]
    val_score = [0]

    if self.snapshot is not None:
      print('loading model from : {}'.format(self.snapshot))
      checkpoint = torch.load(self.snapshot)
      total_epoch = checkpoint['epochs']
      encoder.load_state_dict(checkpoint['encoder'])
      decoder.load_state_dict(checkpoint['decoder'])
      encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer'])
      decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer'])
    
    t = tqdm(range(1, self.epochs + 1))
    try:
        for epoch in t:
            loss = self.train_batches(encoder, encoder_optimizer, decoder,
                                      decoder_optimizer, train_dataloader)
            if epoch % self.eval_frequency == 0:
                checkpoint_epoch += self.eval_frequency
                train_predict = pred_.predict_dataset(encoder, decoder,
                                                      train_data)
                if dev_data is None:
                    t.set_description("loss %s" % loss)
                    print('Please provide a dev set to visualize the plot')
                else:
                    dev_predict = pred_.predict_dataset(encoder, decoder,
                                                        dev_data)
                    bleu = bleu_score(dev_predict[1], dev_predict[0])
                    t.set_description("loss {}, bleu {}".format(loss, bleu))

                path = self.path + "/model" + str(total_epoch + epoch) + ".pt"
                save_model(checkpoint_epoch + total_epoch, encoder, decoder,
                           encoder_optimizer, decoder_optimizer, path)
    except KeyboardInterrupt:
        # Save a checkpoint on interrupt, reusing the naming scheme above
        # (computed here so `path` is defined even before the first eval).
        print('save model')
        path = self.path + "/model" + str(total_epoch + epoch) + ".pt"
        save_model(checkpoint_epoch + total_epoch, encoder, decoder,
                   encoder_optimizer, decoder_optimizer, path)
Example #5
def on_batch_begin(self, last_input, last_target, train, **kwargs):
    """
    Calculates output sentence using beam search for every batch of
    validation examples.
    """
    if not train:
        batch_size = last_input.size(0)
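        # Process the batch in chunks of at most 16 examples, presumably to bound beam-search memory.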
        for in_data in last_input.split(min(batch_size, 16), dim=0):
            src_data, tgt_data = in_data.split([self.Ts, self.Tt], dim=1)
            out_data = beam_search(self.learn.model, src_data,
                                   self.beam_size, self.Tt - 1)
            bleu = 0.0
            for b in range(out_data.shape[0]):
                out_len = out_data.shape[1]
                tgt_len = tgt_data.shape[1]
                out_l = []
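                # Truncate the hypothesis at the first EOS or PAD token.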
                for i in range(out_data.shape[1]):
                    if (out_data[b][i].item() == self.eos
                            or out_data[b][i].item() == self.pad):
                        out_len = i
                        break
                    out_l.append(str(out_data[b][i].item()))
                tgt_l = []
                for i in range(1, self.Tt):
                    if (tgt_data[b][i].item() == self.eos
                            or tgt_data[b][i].item() == self.pad):
                        tgt_len = i
                        # The Moses BLEU script gives 0 for sentences shorter
                        # than four tokens, so ignore their BLEU scores.
                        if i < 4:
                            batch_size -= 1
                        break
                    tgt_l.append(str(tgt_data[b][i].item()))
                if out_len >= 4 and tgt_len >= 4:
                    self.count += 1
                    bleu += bleu_score([' '.join(out_l)],
                                       [[' '.join(tgt_l)]]) * 100
            # Accumulate inside the chunk loop so every chunk contributes,
            # not just the last one.
            self.bleu += bleu
Example #6
def main(args):
    # Folder setting
    if not os.path.isdir('./model'):
        os.mkdir('./model')
    if not os.path.isdir('./result'):
        os.mkdir('./result')

    # Hyperparameter settings
    epochs = args.epochs
    batch_size = args.batch_size
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = DataSetWrapper(batch_size,
                             args.num_workers,
                             args.valid_size,
                             input_shape=(224, 224, 3))
    train_loader, valid_loader = dataset.get_data_loaders()
    vocab = dataset.dataset.vocab

    model = make_model(len(vocab), 512, vocab.stoi['<PAD>'],
                       vocab.stoi['<SOS>'], vocab.stoi['<EOS>'],
                       device).to(device)
    optimizer = torch.optim.RMSprop(params=model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss().to(device)

    # torch.set_printoptions(profile='full')
    # torch.autograd.set_detect_anomaly(True)
    print("Using device: " + str(device))

    if args.prev_model != '':
        checkpoint = torch.load('./model/' + args.prev_model,
                                map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # model.eval()

    if not args.inference:
        # Train
        best_bleu = 0
        train_loss_list, valid_loss_list = [], []
        bleu_list = []

        print('Start training')
        for epoch in range(epochs):
            train_total_loss, valid_total_loss = 0.0, 0.0
            model.train()
            for x, tgt, tgt_len, _ in train_loader:
                # tgt: (batch_size, tgt_len)
                # (batch_size, tgt_len, vocab_size), (batch_size, tgt_len, num pixels)
                x = x.to(device)
                tgt = tgt.to(device)
                tgt_len = tgt_len.to(device)
                predictions, tgt_len, alpha = model(x, tgt, tgt_len, True)

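                # Align predictions with the shifted targets by dropping the first position.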
                predictions = predictions[:, 1:]
                packed_predictions = pack_padded_sequence(
                    predictions, tgt_len, batch_first=True)[0].to(device)
                tgt = tgt[:, 1:]
                packed_tgt = pack_padded_sequence(
                    tgt, tgt_len, batch_first=True)[0].to(device)

                optimizer.zero_grad()
                loss = criterion(packed_predictions, packed_tgt)
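                # Doubly stochastic attention regularization, as in Show, Attend and Tell.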
                loss += ((1 - alpha.sum(dim=1))**2).mean()
                loss.backward()
                optimizer.step()

                _, predictions = torch.max(predictions, dim=2)

                train_total_loss += loss.item()  # .item() avoids keeping the graph alive

            hypotheses = []
            references = []
            model.eval()

            with torch.no_grad():
                for x, tgt, tgt_len, all_tgt in valid_loader:
                    x = x.to(device)
                    tgt = tgt.to(device)
                    tgt_len = tgt_len.to(device)
                    predictions, tgt_len, alpha = model(x, tgt, tgt_len, False)

                    predictions = predictions[:, 1:]
                    packed_predictions = pack_padded_sequence(
                        predictions, tgt_len, batch_first=True)[0].to(device)
                    tgt = tgt[:, 1:]
                    packed_tgt = pack_padded_sequence(
                        tgt, tgt_len, batch_first=True)[0].to(device)

                    loss = criterion(packed_predictions, packed_tgt)
                    loss += ((1 - alpha.sum(dim=1))**2).mean()

                    valid_total_loss += loss.item()

                    _, predictions = torch.max(predictions, dim=2)

                    # Calculate BLEU
                    # TODO: Collect all reference captions, not one
                    predictions = predictions.cpu().tolist()
                    all_tgt = all_tgt.tolist()
                    t_predictions = []
                    t_tgt = []
                    for i in range(len(tgt)):
                        t_tgt.append(all_tgt[i])
                        t_predictions.append(predictions[i][:tgt_len[i] - 1])
                    predictions = t_predictions
                    tgt = t_tgt

                    hypotheses.extend(predictions)
                    references.extend(tgt)
                    assert len(references) == len(hypotheses)

            bleus = bleu_score(hypotheses, references)

            bleu_list.append(bleus[0])
            train_loss_list.append(train_total_loss)
            valid_loss_list.append(valid_total_loss)
            if (not args.no_save) and best_bleu <= bleus[0]:
                best_bleu = bleus[0]
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'bleu': best_bleu
                    }, './model/' + str(epochs) + ".pt")
                print('Model saved')

            print(
                time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + "|| [" +
                str(epoch) + "/" + str(epochs) + "], train_loss = " +
                str(train_total_loss) + ", valid_loss = " +
                str(valid_total_loss) + ", BLEU = " + str(bleus[4]) + ', ' +
                str(bleus[0]) + '/' + str(bleus[1]) + '/' + str(bleus[2]) +
                '/' + str(bleus[3]))

        plot_loss_curve(train_loss_list, valid_loss_list, bleu_list)
    else:
        print("Start Inference")
        """
        Show image with caption
        ref: https://github.com/AaronCCWong/Show-Attend-and-Tell
        """

        # Load model
        assert args.prev_model != '' and args.img_path != ''
        checkpoint = torch.load('./model/' + args.prev_model,
                                map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        model.eval()

        img = Image.open(args.img_path).convert('RGB')
        img = dataset.dataset.transform(img)
        img = torch.FloatTensor(img)
        img = img.unsqueeze(0).to(device)  # move the input to the same device as the model

        img_feature = model.encode(img)
        img_feature = img_feature.view(img_feature.shape[0], -1,
                                       img_feature.shape[-1])
        tgt_len = [torch.tensor(args.max_length)]
        tgt = torch.zeros(1, args.max_length)
        sentence, tgt_len, alpha = model.decode(img_feature, tgt, tgt_len,
                                                False, True)
        _, sentence = torch.max(sentence, dim=2)
        sentence = sentence.squeeze()
        sentence = sentence[:tgt_len[0]]

        sentence = vocab.interpret(sentence.tolist())
        print(sentence)

        img = Image.open(args.img_path)
        w, h = img.size
        if w > h:
            w = w * 256 / h
            h = 256
        else:
            h = h * 256 / w
            w = 256
        left = (w - 224) / 2
        top = (h - 224) / 2
        resized_img = img.resize((int(w), int(h)), Image.BICUBIC).crop(
            (left, top, left + 224, top + 224))
        img = np.array(resized_img.convert('RGB').getdata()).reshape(
            224, 224, 3)
        img = img.astype('float32') / 255

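        # Plot each word's attention map over the resized image (Show, Attend and Tell-style).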
        num_words = len(sentence)
        w = np.round(np.sqrt(num_words))
        h = np.ceil(np.float32(num_words) / w)
        alpha = alpha.clone().detach().squeeze().cpu()  # bring attention weights to CPU for plotting

        plot_height = int(np.ceil((num_words + 3) / 4.0))
        ax1 = plt.subplot(4, plot_height, 1)
        plt.imshow(img)
        plt.axis('off')
        for idx in range(num_words):
            ax2 = plt.subplot(4, plot_height, idx + 2)
            label = sentence[idx]
            plt.text(0, 1, label, backgroundcolor='white', fontsize=13)
            plt.text(0, 1, label, color='black', fontsize=13)
            plt.imshow(img)

            shape_size = 14
            alpha_img = skimage.transform.pyramid_expand(alpha[idx, :].reshape(
                shape_size, shape_size),
                                                         upscale=16,
                                                         sigma=20)
            plt.imshow(alpha_img, alpha=0.8)
            plt.set_cmap(cm.Greys_r)
            plt.axis('off')

        plt.show()
Example #7
workspace_path = lambda file_path: os.path.join(data_root_dir, file_path)
paras_file, titles_file, embedding_file = (workspace_path(paras_file),
                                           workspace_path(titles_file),
                                           workspace_path(embedding_file))

input_data, inputs_for_tf, input_placeholders = input_generator.get_inputs(
    paras_file, titles_file, embedding_file, FLAGS)

###########
## Model ##
###########
outputs = RNNModel(inputs_for_tf,
                   FLAGS,
                   is_training=IS_TRAINING,
                   multirnn=False)
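# Approximate BLEU between sampled ids and the reference titles
# (first reference token dropped, likely a start token).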
bleu_score = bleu.bleu_score(predictions=outputs.sample_id,
                             labels=inputs_for_tf['title_batch'][:, 1:])

##############
## Training ##
##############
sess = tf.Session()
tf.tables_initializer().run(session=sess)
sess.run(tf.global_variables_initializer())
sess.run(
    input_placeholders['embedding_init'],
    feed_dict={input_placeholders['embedding_ph']: input_data['embedding']})
saver = tf.train.import_meta_graph('checkpoints.meta')
saver.restore(sess, tf.train.latest_checkpoint('.'))

sess.run(inputs_for_tf['iterator'].initializer,
         feed_dict={
Example #8
    def model_fn(features, labels, mode):
        with tf.variable_scope("transformer_vqvae",
                               initializer=tf.variance_scaling_initializer(
                                   hparams.initializer_gain,
                                   mode="fan_avg",
                                   distribution="uniform")):

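            # Zero out every dropout hyperparameter outside of training.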
            if mode != tf.estimator.ModeKeys.TRAIN:
                for key in hparams.keys():
                    if key.endswith("dropout"):
                        setattr(hparams, key, 0.0)

            with tf.variable_scope("embeddings",
                                   initializer=tf.random_normal_initializer(
                                       0.0, hparams.hidden_size**-0.5)):
                source_embeddings = tf.get_variable(
                    "source_embeddings",
                    [len(hparams.source_vocab), hparams.hidden_size],
                    tf.float32)
                if hparams.shared_embedding:
                    target_embeddings = source_embeddings
                else:
                    target_embeddings = tf.get_variable(
                        "target_embeddings",
                        [len(hparams.target_vocab), hparams.hidden_size],
                        tf.float32)

            encoder_input_layer = commons.input_layer(source_embeddings,
                                                      hparams)
            decoder_input_layer = commons.input_layer(target_embeddings,
                                                      hparams)
            output_layer = tf.layers.Dense(len(hparams.target_vocab),
                                           use_bias=False,
                                           name="output")

            # create model
            x_enc = encoder_input_layer(features["sources"])
            model = TransformerNAT(hparams, mode)

            # decode
            if mode != tf.estimator.ModeKeys.PREDICT:
                x_dec = decoder_input_layer(features["targets"])

                decoder_outputs, losses = model.body(features={
                    'inputs': x_enc,
                    'targets': x_dec
                })
                logits = output_layer(decoder_outputs)
                predictions = tf.argmax(logits, -1)
                tgt_len = commons.shape_list(features["targets"])[1]
                losses["cross_entropy"] = commons.compute_loss(
                    logits[:, :tgt_len], features["targets"])

                # losses
                loss = 0.
                for k, l in losses.items():
                    tf.summary.scalar(k, l)
                    loss += l

            else:
                decoder_outputs = model.infer(features={'inputs': x_enc})
                logits = output_layer(decoder_outputs)
                predictions = tf.argmax(logits, -1)
                loss = None

        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = commons.get_train_op(loss, hparams)
        else:
            train_op = None

        if mode == tf.estimator.ModeKeys.EVAL:
            # Names tensors to use in the printing tensor hook.
            targets = tf.identity(features["targets"], "targets")
            predictions = tf.identity(predictions, "predictions")

            bleu_score = bleu.bleu_score(predictions, targets)
            eval_metrics = {
                "metrics/approx_bleu_score": tf.metrics.mean(bleu_score)
            }

            # Summaries
            eval_summary_hook = tf.train.SummarySaverHook(
                save_steps=1,
                output_dir=os.path.join(hparams.model_dir, "eval"),
                summary_op=tf.summary.merge_all())
            eval_summary_hooks = [eval_summary_hook]
        else:
            eval_metrics = None
            eval_summary_hooks = None

        return tf.estimator.EstimatorSpec(mode,
                                          predictions=predictions,
                                          loss=loss,
                                          eval_metric_ops=eval_metrics,
                                          evaluation_hooks=eval_summary_hooks,
                                          train_op=train_op)