Example #1
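A TensorFlow/Keras training-and-evaluation loop: it restores the latest checkpoint, alternates training and evaluation passes each epoch, logs metrics and images, and finally exports the model as a SavedModel.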
import time

import tensorflow as tf
from tensorflow import keras

# get_model, get_dataset, Logger, Checkpoint, train_step and eval_step are
# helpers defined elsewhere in the surrounding project.
def train_and_eval(hparams):
  model = get_model(hparams)
  # `lr` is deprecated in the Keras optimizer API; use `learning_rate`.
  optimizer = keras.optimizers.Adam(learning_rate=hparams.learning_rate)
  loss_fn = keras.losses.SparseCategoricalCrossentropy()
  logger = Logger(hparams, optimizer)

  train_dataset = get_dataset(hparams, train=True)
  eval_dataset = get_dataset(hparams, train=False)

  # Resume from the latest checkpoint, if one exists.
  checkpoint = Checkpoint(hparams, optimizer, model)
  checkpoint.restore()

  for epoch in range(hparams.epochs):

    start = time.time()

    # One full pass over the training set.
    for images, labels in train_dataset:
      loss, predictions = train_step(images, labels, model, optimizer, loss_fn)
      logger.log_progress(loss, labels, predictions, mode='train')

    elapse = time.time() - start

    logger.write_scalars(mode='train')

    # Evaluate on the held-out set, logging sample images as we go.
    for images, labels in eval_dataset:
      logger.write_images(images, mode='eval')
      loss, predictions = eval_step(images, labels, model, loss_fn)
      logger.log_progress(loss, labels, predictions, mode='eval')

    logger.write_scalars(mode='eval', elapse=elapse)

    logger.print_progress(epoch, elapse)

    # Checkpoint every fifth epoch and after the final epoch.
    if epoch % 5 == 0 or epoch == hparams.epochs - 1:
      checkpoint.save()

  tf.keras.models.save_model(model, filepath=hparams.save_model)
  print(f'model saved at {hparams.save_model}')
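A minimal sketch of how the function above might be driven. The hparams object here is hypothetical: the code only shows that it must expose learning_rate, epochs and save_model (plus whatever get_model, get_dataset, Logger and Checkpoint read from it), and the values below are illustrative only.

from types import SimpleNamespace

# Hypothetical hyperparameter bundle; any object with these attributes works.
hparams = SimpleNamespace(
    learning_rate=1e-3,        # Adam step size
    epochs=20,                 # number of training epochs
    save_model='saved_model',  # export path for the final SavedModel
)
train_and_eval(hparams)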
Example #2
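A PyTorch trainer method that iterates over span-encoded batches, moves tensors to the GPU when available, logs a running training loss, evaluates template and answer accuracy after every epoch, and writes a checkpoint, flagging it as best whenever test answer accuracy improves.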
    # Method of a trainer class; `torch`, `logging` and the project's
    # Checkpoint helper are assumed to be imported at module level.
    def _train_epoches(self, data_loader, model, batch_size, start_epoch,
                       start_step, max_acc, n_epoch):
        train_list = data_loader.train_list
        test_list = data_loader.test_list

        step = start_step
        print_loss_total = 0   # running loss, reset after every log line
        max_ans_acc = max_acc  # best test answer accuracy seen so far

        for epoch_index, epoch in enumerate(range(start_epoch, n_epoch + 1)):
            model.train()  # switch to training mode
            batch_generator = data_loader.get_batch(train_list,
                                                    batch_size,
                                                    template_flag=True)
            for batch_data_dict in batch_generator:
                step += 1
                input_variables = batch_data_dict['batch_span_encode_idx']
                input_lengths = batch_data_dict['batch_span_encode_len']
                span_length = batch_data_dict['batch_span_len']
                tree = batch_data_dict["batch_tree"]

                # Wrap the raw index lists as LongTensors, one tensor per span.
                input_variables = [
                    torch.LongTensor(input_variable)
                    for input_variable in input_variables
                ]
                input_lengths = [
                    torch.LongTensor(input_length)
                    for input_length in input_lengths
                ]
                span_length = torch.LongTensor(span_length)
                # Move the batch to the GPU when available.
                if self.use_cuda:
                    input_variables = [
                        input_variable.cuda()
                        for input_variable in input_variables
                    ]
                    input_lengths = [
                        input_length.cuda() for input_length in input_lengths
                    ]
                    span_length = span_length.cuda()

                # Number-position indices at span level and word level.
                span_num_pos = batch_data_dict["batch_span_num_pos"]
                word_num_poses = batch_data_dict["batch_word_num_poses"]
                span_num_pos = torch.LongTensor(span_num_pos)
                word_num_poses = [
                    torch.LongTensor(word_num_pos)
                    for word_num_pos in word_num_poses
                ]
                if self.use_cuda:
                    span_num_pos = span_num_pos.cuda()
                    word_num_poses = [
                        word_num_pose.cuda()
                        for word_num_pose in word_num_poses
                    ]
                num_pos = (span_num_pos, word_num_poses)

                target_variables = batch_data_dict['batch_decode_idx']
                target_variables = torch.LongTensor(target_variables)
                if self.use_cuda:
                    target_variables = target_variables.cuda()

                # One optimization step on this batch; returns the batch loss.
                loss = self._train_batch(input_variables=input_variables,
                                         num_pos=num_pos,
                                         input_lengths=input_lengths,
                                         span_length=span_length,
                                         target_variables=target_variables,
                                         tree=tree,
                                         model=model,
                                         batch_size=batch_size)

                print_loss_total += loss
                if step % self.print_every == 0:
                    print_loss_avg = print_loss_total / self.print_every
                    print_loss_total = 0
                    logging.info(
                        f'step: {step}, Train loss: {print_loss_avg:.4f}')
                    if self.use_cuda:
                        # Release cached GPU memory between logging intervals.
                        torch.cuda.empty_cache()
            self.scheduler.step()  # advance the LR schedule once per epoch

            # Evaluate on the test set after every epoch.
            model.eval()
            with torch.no_grad():
                test_temp_acc, test_ans_acc = self.evaluator.evaluate(
                    model=model,
                    data_loader=data_loader,
                    data_list=test_list,
                    template_flag=True,
                    template_len=True,
                    batch_size=batch_size,
                )
                # Periodically also measure accuracy on the training set.
                if epoch_index % self.test_train_every == 0:
                    train_temp_acc, train_ans_acc = self.evaluator.evaluate(
                        model=model,
                        data_loader=data_loader,
                        data_list=train_list,
                        template_flag=True,
                        template_len=True,
                        batch_size=batch_size,
                    )

                    logging.info(
                        f"Epoch: {epoch}, Step: {step}, "
                        f"test_acc: {test_temp_acc:.3f}, {test_ans_acc:.3f}, "
                        f"train_acc: {train_temp_acc:.3f}, {train_ans_acc:.3f}"
                    )
                else:
                    logging.info(
                        f"Epoch: {epoch}, Step: {step}, test_acc: {test_temp_acc:.3f}, {test_ans_acc:.3f}"
                    )

            # Save a checkpoint every epoch; flag it as best whenever test
            # answer accuracy improves.
            is_best = test_ans_acc > max_ans_acc
            if is_best:
                max_ans_acc = test_ans_acc
                logging.info("saving checkpoint ...")
            Checkpoint.save(epoch=epoch,
                            step=step,
                            max_acc=max_ans_acc,
                            model=model,
                            optimizer=self.optimizer,
                            scheduler=self.scheduler,
                            best=is_best)
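One design point worth noting: the loop steps the learning-rate scheduler once per epoch rather than per batch, and it writes a checkpoint after every epoch while tracking the best test answer accuracy separately. Together with the start_epoch, start_step and max_acc parameters, this lets a run resume from the most recent state or fall back to the best-performing one.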