def main():
    """Train a Camelyon patch classifier with Chainer (single- or multi-GPU).

    Parses CLI options, loads the pickled train/test patch datasets, builds a
    texture (bilinear) CNN or a plain trainable CNN, then runs an
    iteration-based loop with periodic evaluation, logging, resume
    checkpointing, and a x0.1 learning-rate decay every 20k iterations.
    With ``--model_test`` it only evaluates and exits.
    """
    parser = argparse.ArgumentParser(description='gpat train ')
    parser.add_argument("out")
    parser.add_argument('--resume', default=None)
    parser.add_argument('--log_dir', default='runs_16')
    parser.add_argument('--gpus',
                        '-g',
                        type=int,
                        nargs="*",
                        default=[0, 1, 2, 3])
    parser.add_argument('--iterations',
                        default=10**5,
                        type=int,
                        help='number of iterations to learn')
    parser.add_argument('--interval',
                        default=1000,
                        type=int,
                        help='number of iterations to evaluate')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=128,
                        help='learning minibatch size')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--loaderjob', type=int, default=8)
    parser.add_argument('--hed',
                        dest='hed',
                        action='store_true',
                        default=False)
    # parser.add_argument('--size', '-s', default=96, type=int, choices=[48, 64, 80, 96, 112, 128],
    #                     help='image size')
    parser.add_argument('--no-texture',
                        dest='texture',
                        action='store_false',
                        default=True)
    parser.add_argument('--cbp',
                        dest='cbp',
                        action='store_true',
                        default=False)
    parser.add_argument('--no-color_aug',
                        dest='color_aug',
                        action='store_false',
                        default=True)
    parser.add_argument('--model_test', default='', type=str)
    parser.add_argument('--no-finetune',
                        dest='finetune',
                        action='store_false',
                        default=True)
    parser.add_argument('--arch',
                        default='googlenet',
                        choices=[
                            'texturecnn', 'resnet50', 'resnet101', 'googlenet',
                            'vgg', 'alex', 'trained', 'resume'
                        ])
    parser.add_argument('--opt', default='adam', choices=['adam', 'momentum'])
    parser.add_argument('--train_path', default='train_extracted_dataset.pkl')
    parser.add_argument('--test_path', default='test_extracted_dataset.pkl')
    parser.add_argument('--data_size', type=float, default=1.)
    parser.add_argument('--new', action='store_true', default=False)
    args = parser.parse_args()

    devices = tuple(args.gpus)
    # os.environ['PATH'] += ':/usr/local/cuda/bin'

    # log directory
    logger.init(args)

    # load data (pickled lists of patch records)
    train_dataset = np.load(os.path.join(dataset_path, args.train_path))
    test_dataset = np.load(os.path.join(dataset_path, args.test_path))
    num_class = 2
    image_size = 256
    crop_size = 224

    # Optionally subsample pre-extracted training data to a `data_size`
    # fraction. BUGFIX: test the *path string* for 'extracted' (the same
    # check used below to pick the dataset class); the old code tested
    # membership in the loaded dataset object itself.
    if 'extracted' in args.train_path:
        perm = np.random.permutation(
            len(train_dataset))[:int(len(train_dataset) * args.data_size)]
        train_dataset = [train_dataset[p] for p in perm]

    # '--hed' switches preprocessing to HED stain space regardless of arch.
    preprocess_type = args.arch if not args.hed else 'hed'
    if 'extracted' in args.train_path:
        train = CamelyonDatasetEx(train_dataset,
                                  original_size=image_size,
                                  crop_size=crop_size,
                                  aug=True,
                                  color_aug=args.color_aug,
                                  preprocess_type=preprocess_type)
    else:
        train = CamelyonDatasetFromTif(train_dataset,
                                       original_size=image_size,
                                       crop_size=crop_size,
                                       aug=True,
                                       color_aug=args.color_aug,
                                       preprocess_type=preprocess_type)
    # Multi-GPU: one iterator per device over a random split of the data,
    # as required by MultiprocessParallelUpdater.
    if len(devices) > 1:
        train_iter = [
            chainer.iterators.MultiprocessIterator(i,
                                                   args.batch_size,
                                                   n_processes=args.loaderjob)
            for i in chainer.datasets.split_dataset_n_random(
                train, len(devices))
        ]
    else:
        train_iter = iterators.MultiprocessIterator(train,
                                                    args.batch_size,
                                                    n_processes=args.loaderjob)

    # Test set is always the pre-extracted dataset, no augmentation.
    test = CamelyonDatasetEx(test_dataset,
                             original_size=image_size,
                             crop_size=crop_size,
                             aug=False,
                             color_aug=False,
                             preprocess_type=preprocess_type)
    test_iter = iterators.MultiprocessIterator(test,
                                               args.batch_size,
                                               repeat=False,
                                               shuffle=False)

    # model construct: bilinear (texture) head by default, plain CNN otherwise
    if args.texture:
        model = BilinearCNN(base_cnn=args.arch,
                            pretrained_model='auto',
                            num_class=num_class,
                            texture_layer=None,
                            cbp=args.cbp,
                            cbp_size=4096)
    else:
        model = TrainableCNN(base_cnn=args.arch,
                             pretrained_model='auto',
                             num_class=num_class)

    # Evaluation-only mode: run the test set once and exit.
    if args.model_test:
        # test
        # model_path = os.path.join('runs_16', args.model_test, 'models',
        #                           sorted(os.listdir(os.path.join('runs_16', args.model_test, 'models')))[-1])
        # print(model_path)
        # chainer.serializers.load_npz(model_path, model)
        cuda.get_device_from_id(devices[0]).use()
        model.to_gpu()
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            evaluate_ex(model, test_iter, devices[0])
        logger.flush()
        exit()

    # Resume from the latest (lexicographically last) saved model snapshot.
    if args.resume is not None:
        model_path = os.path.join(
            'runs_16', args.resume, 'models',
            sorted(os.listdir(os.path.join('runs_16', args.resume,
                                           'models')))[-1])
        print(model_path)
        chainer.serializers.load_npz(model_path, model)

    # set optimizer
    optimizer = make_optimizer(model, args.opt, args.lr)

    if len(devices) > 1:
        updater = updaters.MultiprocessParallelUpdater(train_iter,
                                                       optimizer,
                                                       devices=devices)
    else:
        cuda.get_device_from_id(devices[0]).use()
        model.to_gpu()
        # updater
        updater = chainer.training.StandardUpdater(train_iter,
                                                   optimizer,
                                                   device=devices[0])

    # start training
    start = time.time()
    train_loss = 0
    train_accuracy = 0
    while updater.iteration < args.iterations:

        # train
        updater.update()
        progress_report(updater.iteration, start,
                        len(devices) * args.batch_size, len(train))
        train_loss += model.loss.data
        train_accuracy += model.accuracy.data

        # Every `interval` iterations: log running averages, evaluate,
        # flush the logger and save a resume snapshot of the updater.
        if updater.iteration % args.interval == 0:
            logger.plot('train_loss', cuda.to_cpu(train_loss) / args.interval)
            logger.plot('train_accuracy',
                        cuda.to_cpu(train_accuracy) / args.interval)
            train_loss = 0
            train_accuracy = 0

            # test
            with chainer.using_config('train',
                                      False), chainer.no_backprop_mode():
                evaluate_ex(model, test_iter, devices[0])

            # logger
            logger.flush()

            # save
            serializers.save_npz(os.path.join(logger.out_dir, 'resume'),
                                 updater)

            # Step LR decay; Adam exposes `alpha`, MomentumSGD exposes `lr`.
            if updater.iteration % 20000 == 0:
                if args.opt == 'adam':
                    optimizer.alpha *= 0.1
                else:
                    optimizer.lr *= 0.1
# 예제 #2 (snippet separator from the original source collection)
# 0
def main(argv):
    """Train/evaluate with a two-optimizer (fixed vs. added) scheme in TF1.

    Runs ``FLAGS.epochs`` epochs. Up to ``FLAGS.stop_point`` the "fixed"
    optimizer trains the model, except that every ``FLAGS.iteration``-th
    epoch restores the latest checkpoint and fits with the "added" optimizer
    instead. After ``FLAGS.stop_point`` it restores the ``var_m1`` subset and
    continues with the "stop" optimizer. Per-epoch test accuracies of the
    original and proposed heads are collected and plotted at the end.

    NOTE(review): ``train_iterator``/``test_iterator`` are always ``None``
    here and are passed through to ``utils`` helpers — presumably the
    helpers feed data directly in that case; confirm against utils.
    """
    del argv  # unused (absl-style entry point signature)
    save, logdir, figname, logHandler = utils.configuration(FLAGS)

    train_ds, test_ds, placeholder = get_dataset(FLAGS)
    loss, correct_prediction, var_list = utils.load_model(FLAGS, placeholder)

    train_iterator = None
    test_iterator = None

    # Three optimizers for the three training phases (see docstring).
    fix_opt, add_opt, stop_opt = utils.make_optimizer(placeholder, loss,
                                                      var_list)
    fix_accuracy, add_accuracy = correct_prediction

    save_dir, save_file = save
    # var_all: every variable; var_m1: the subset restored after stop_point.
    var_all, var_m1, _ = var_list

    # Per-epoch histories for the final accuracy plot.
    epoch_list, original, proposed = [], [], []

    with tf.Session() as sess:
        with tf.device('/cpu:0'):
            merged_summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(logdir)
            writer.add_graph(sess.graph)
            sess.run(tf.global_variables_initializer())
            # Saver over all variables; used for the per-epoch checkpoint.
            saver = tf.train.Saver(var_all)

        print('Learning started. It takes sometimes...')
        print()
        for i in range(1, FLAGS.epochs + 1):
            logHandler.print_epoch()
            # Phase switch: first epoch after stop_point restores only the
            # m1 variable subset from the latest checkpoint.
            if i == (FLAGS.stop_point + 1):
                logHandler._print('Proposed training...')
                loader = tf.train.Saver(var_m1)
                loader.restore(sess, tf.train.latest_checkpoint(save_dir))

            if i <= FLAGS.stop_point:
                # Every FLAGS.iteration-th epoch: restore all variables and
                # fit with the "added" optimizer instead of the fixed one.
                if i % FLAGS.iteration == 0:
                    loader = tf.train.Saver(var_all)
                    loader.restore(sess, tf.train.latest_checkpoint(save_dir))

                    utils.fit_model(sess, add_opt, placeholder, train_iterator,
                                    train_ds, i, FLAGS, logHandler,
                                    merged_summary, writer)

                    origin_test_accuracy = utils.test_validate(
                        sess, fix_accuracy, test_iterator, placeholder,
                        test_ds, FLAGS, logHandler)
                    proposed_test_accuracy = utils.test_validate(
                        sess, add_accuracy, test_iterator, placeholder,
                        test_ds, FLAGS, logHandler)

                else:
                    utils.fit_model(sess, fix_opt, placeholder, train_iterator,
                                    train_ds, i, FLAGS, logHandler,
                                    merged_summary, writer)

                    utils.train_validate(sess, fix_accuracy, train_iterator,
                                         placeholder, train_ds, FLAGS,
                                         logHandler)
                    origin_test_accuracy = utils.test_validate(
                        sess, fix_accuracy, test_iterator, placeholder,
                        test_ds, FLAGS, logHandler)

                    proposed_test_accuracy = utils.test_validate(
                        sess, add_accuracy, test_iterator, placeholder,
                        test_ds, FLAGS, logHandler)
                # Checkpoint after every pre-stop_point epoch.
                saver.save(sess, save_file)
            else:
                # Post-stop_point phase: train with the "stop" optimizer.
                # loader = tf.train.Saver(var_m1)
                # loader.restore(sess, tf.train.latest_checkpoint(save_dir))
                utils.fit_model(sess, stop_opt, placeholder, train_iterator,
                                train_ds, i, FLAGS, logHandler, merged_summary,
                                writer)

                if train_iterator is not None:
                    sess.run(train_iterator.initializer)
                utils.train_validate(sess, add_accuracy, train_iterator,
                                     placeholder, train_ds, FLAGS, logHandler)
                proposed_test_accuracy = utils.test_validate(
                    sess, add_accuracy, test_iterator, placeholder, test_ds,
                    FLAGS, logHandler)

                origin_test_accuracy = utils.test_validate(
                    sess, fix_accuracy, test_iterator, placeholder, test_ds,
                    FLAGS, logHandler)

            epoch_list.append(i)
            proposed.append(proposed_test_accuracy)
            original.append(origin_test_accuracy)

        # Final reporting: one more validation pass per head, then plot.
        # Add_final_train_accuracy = tu.train_validate(sess, add_accuracy, train_iterator,
        #                                         X, Y, dropout_rate, train_ds, FLAGS)
        logHandler._print('Original Accuracy: ')
        origin_test_accuracy = utils.test_validate(sess, fix_accuracy,
                                                   test_iterator, placeholder,
                                                   test_ds, FLAGS, logHandler)

        logHandler._print('Proposed Accuracy: ')
        utils.test_validate(sess, add_accuracy, test_iterator, placeholder,
                            test_ds, FLAGS, logHandler)

        plot_acc(epoch_list, original, proposed, figname)
        saver.save(sess, save_file)
        logHandler._print('Training done successfully')
def main():
    """Active-learning training loop for Camelyon patch classification.

    Repeatedly (up to 100 rounds): trains a fresh model on the current
    labeled pool, evaluates on a fixed 10k test subset, then queries
    ``active_sample_size`` new samples to label — randomly or by
    committee-based uncertainty over a 10k subsample of the unlabeled pool —
    until the labeled pool reaches 10k samples.
    """
    parser = argparse.ArgumentParser(description='gpat train')
    parser.add_argument("out")
    parser.add_argument('--resume', default=None)
    parser.add_argument('--log_dir', default='runs_active')
    parser.add_argument('--gpu', '-g', type=int, default=0)
    parser.add_argument('--iterations',
                        default=10**5,
                        type=int,
                        help='number of iterations to learn')
    parser.add_argument('--interval',
                        default=100,
                        type=int,
                        help='number of iterations to evaluate')
    parser.add_argument('--batch_size',
                        '-b',
                        type=int,
                        default=64,
                        help='learning minibatch size')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--loaderjob', type=int, default=8)
    parser.add_argument('--hed',
                        dest='hed',
                        action='store_true',
                        default=False)
    parser.add_argument('--from_tiff',
                        dest='from_tiff',
                        action='store_true',
                        default=False)
    parser.add_argument('--no-texture',
                        dest='texture',
                        action='store_false',
                        default=True)
    parser.add_argument('--cbp',
                        dest='cbp',
                        action='store_true',
                        default=False)
    parser.add_argument('--no-color_aug',
                        dest='color_aug',
                        action='store_false',
                        default=True)
    parser.add_argument('--model_test', default='', type=str)

    parser.add_argument('--arch',
                        default='googlenet',
                        choices=[
                            'texturecnn', 'resnet', 'googlenet', 'vgg', 'alex',
                            'trained', 'resume'
                        ])
    parser.add_argument('--opt', default='adam', choices=['adam', 'momentum'])
    parser.add_argument('--train_path', default='train_extracted_dataset.pkl')
    parser.add_argument('--test_path', default='test_extracted_dataset.pkl')

    parser.add_argument('--epoch_interval', default=20, type=int)

    parser.add_argument('--active_sample_size', type=int, default=100)
    parser.add_argument('--no-every_init',
                        dest='every_init',
                        action='store_false',
                        default=True)

    parser.add_argument('--random_sample', action='store_true', default=False)
    parser.add_argument('--fixed_ratio', action='store_true', default=False)
    parser.add_argument('--label_init',
                        choices=['random', 'clustering'],
                        default='clustering')
    parser.add_argument('--init_size', default=100, type=int)

    parser.add_argument('--uncertain', action='store_true', default=False)
    parser.add_argument('--uncertain_with_dropout',
                        action='store_true',
                        default=False)
    parser.add_argument('--uncertain_strategy',
                        choices=['entropy', 'least_confident', 'margin'],
                        default='margin')

    parser.add_argument('--clustering', action='store_true', default=False)
    parser.add_argument('--kmeans_cache',
                        default='initial_clustering_result.pkl')
    parser.add_argument('--initial_label_cache',
                        default='initial_label_cache.npy')

    parser.add_argument('--query_by_committee',
                        action='store_true',
                        default=False)
    parser.add_argument('--qbc_strategy',
                        choices=['vote', 'average_kl'],
                        default='average_kl')
    parser.add_argument('--committee_size', default=10, type=int)

    parser.add_argument('--aug_in_inference',
                        action='store_true',
                        default=False)

    args = parser.parse_args()

    device = args.gpu

    # log directory
    logger.init(args)

    # load data (pickled lists of patch records)
    train_dataset = np.load(os.path.join(dataset_path, args.train_path))
    test_dataset = np.load(os.path.join(dataset_path, args.test_path))
    num_class = 2
    image_size = 256
    crop_size = 224

    preprocess_type = args.arch if not args.hed else 'hed'
    # Evaluate on a fixed random 10k subset of the test data.
    perm = np.random.permutation(len(test_dataset))[:10000]
    test_dataset = [test_dataset[idx] for idx in perm]
    test = CamelyonDatasetEx(test_dataset,
                             original_size=image_size,
                             crop_size=crop_size,
                             aug=False,
                             color_aug=False,
                             preprocess_type=preprocess_type)
    test_iter = iterators.MultiprocessIterator(test,
                                               args.batch_size,
                                               repeat=False,
                                               shuffle=False)

    # Precomputed CBP features used for clustering-based query selection.
    cbp_feat = np.load('train_cbp512_feat.npy')
    labeled_data, unlabeled_data, feat = initialize_labeled_dataset(
        args, train_dataset, cbp_feat)
    print('now {} labeled samples, {} unlabeled'.format(
        len(labeled_data), len(unlabeled_data)))

    # start training (active-learning rounds)
    reporter = ProgresssReporter(args)
    for iteration in range(100):

        # model construct: fresh model each round
        if args.texture:
            model = BilinearCNN(base_cnn=args.arch,
                                pretrained_model='auto',
                                num_class=num_class,
                                texture_layer=None,
                                cbp=args.cbp,
                                cbp_size=4096)
        else:
            model = TrainableCNN(base_cnn=args.arch,
                                 pretrained_model='auto',
                                 num_class=num_class)

        # set optimizer
        optimizer = make_optimizer(model, args.opt, args.lr)

        # use gpu
        cuda.get_device_from_id(device).use()
        model.to_gpu()

        labeled_dataset = CamelyonDatasetEx(labeled_data,
                                            original_size=image_size,
                                            crop_size=crop_size,
                                            aug=True,
                                            color_aug=True,
                                            preprocess_type=preprocess_type)
        labeled_iter = iterators.MultiprocessIterator(labeled_dataset,
                                                      args.batch_size)

        # train phase
        count = 0
        train_loss = 0
        train_acc = 0
        # Train twice as long once the labeled pool is large.
        # BUGFIX: count labeled *samples* with len(labeled_data); the old
        # code measured len(labeled_data[0]) (the first sample record).
        epoch_interval = args.epoch_interval if len(
            labeled_data) < 10000 else args.epoch_interval * 2
        anneal_epoch = int(epoch_interval * 0.8)
        while labeled_iter.epoch < epoch_interval:
            # train with labeled dataset
            batch = labeled_iter.next()
            x, t = chainer.dataset.concat_examples(batch, device=device)
            optimizer.update(model, x, t)
            reporter(labeled_iter.epoch)

            # One-shot LR decay at 80% of the schedule.
            # BUGFIX: Adam exposes `alpha`; MomentumSGD exposes `lr` — the
            # old code always touched `alpha` and crashed with --opt momentum.
            if labeled_iter.is_new_epoch and labeled_iter.epoch == anneal_epoch:
                if args.opt == 'adam':
                    optimizer.alpha *= 0.1
                else:
                    optimizer.lr *= 0.1

            # Accumulate stats over the last 5 epochs of the schedule.
            # BUGFIX: use the (possibly doubled) epoch_interval, not
            # args.epoch_interval, so the window stays 5 epochs.
            if labeled_iter.epoch > epoch_interval - 5:
                count += len(batch)
                train_loss += model.loss.data * len(batch)
                train_acc += model.accuracy.data * len(batch)

                # if labeled_iter.is_new_epoch:
                #     train_loss_tmp = cuda.to_cpu(train_loss) / len(labeled_iter.dataset)
                #     loss_history.append(train_loss_tmp - np.sum(loss_history))

        reporter.reset()

        logger.plot('train_loss', cuda.to_cpu(train_loss) / count)
        logger.plot('train_accuracy', cuda.to_cpu(train_acc) / count)

        # test
        print('\ntest')
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            evaluate_ex(model, test_iter, device)

        # logger
        logger.flush()

        # Stop once the labeled pool has reached 10k samples.
        # BUGFIX: len(labeled_data), not len(labeled_data[0]).
        if len(labeled_data) >= 10000:
            print('done')
            exit()

        # Query selection runs on a random 10k subsample of the unlabeled pool.
        tmp_indices = np.random.permutation(len(unlabeled_data))[:10000]
        tmp_unlabeled_data = [unlabeled_data[idx] for idx in tmp_indices]
        tmp_cbp_feat = cbp_feat[tmp_indices]

        unlabeled_dataset = CamelyonDatasetEx(tmp_unlabeled_data,
                                              original_size=image_size,
                                              crop_size=crop_size,
                                              aug=args.aug_in_inference,
                                              color_aug=args.aug_in_inference,
                                              preprocess_type=preprocess_type)
        unlabeled_iter = iterators.MultiprocessIterator(unlabeled_dataset,
                                                        args.batch_size,
                                                        repeat=False,
                                                        shuffle=False)

        # Committee predictions: committee_size stochastic forward passes.
        preds = np.zeros((args.committee_size, len(tmp_unlabeled_data), 2))
        # feat = np.zeros((len(unlabeled_iter.dataset), 784))
        if args.random_sample:
            tmp_query_indices = np.random.permutation(
                len(tmp_unlabeled_data))[:args.active_sample_size]
        else:
            loop_num = args.committee_size
            for loop in range(loop_num):
                count = 0
                for batch in unlabeled_iter:
                    x, t = chainer.dataset.concat_examples(batch,
                                                           device=device)
                    with chainer.no_backprop_mode():
                        y = F.softmax(model.forward(x))
                    preds[loop, count:count + len(batch)] = cuda.to_cpu(y.data)
                    count += len(batch)
                    # if loop == 0:
                    #     feat[i * batch_size: (i + 1) * batch_size] = cuda.to_cpu(x)
                unlabeled_iter.reset()
            tmp_query_indices = active_annotation(preds,
                                                  tmp_cbp_feat,
                                                  opt=args)

        # active sampling
        print('active sampling: ', end='')

        # Periodically snapshot the model and the queried samples.
        if iteration % 10 == 0:
            logger.save(model,
                        [tmp_unlabeled_data[idx] for idx in tmp_query_indices])

        # Map subsample-local indices back to the full unlabeled pool and
        # move the queried samples into the labeled pool.
        query_indices = tmp_indices[tmp_query_indices]
        labeled_data, unlabeled_data, cbp_feat = query_dataset(
            labeled_data, unlabeled_data, cbp_feat, query_indices)
        print('now {} labeled samples, {} unlabeled'.format(
            len(labeled_data), len(unlabeled_data)))
# 예제 #4 (snippet separator from the original source collection)
# 0
def train(hparams, models_path = './'):
    """Train the CNN-encoder / RNN-decoder captioner with a custom TF2 loop.

    Args:
        hparams: dict with 'encoder', 'decoder', 'optimizer' and 'train'
            sub-dicts of keyword arguments for the respective constructors.
        models_path: string path prefix under which the encoder/decoder
            weight files are written.

    Returns:
        results: dict
            dictionary containing model identifier, elapsed_time per epoch,
            learning curve with loss and metrics
        models: tuple of keras Models
            the trained encoder and decoder networks

    NOTE(review): the per-token cross-entropy accumulation inside
    `train_step` is commented out, so `losses['cross_entropy']` is always 0
    and gradients come only from the attention-regularization and
    weight-decay terms — confirm this is intentional (ablation?) and not a
    leftover from debugging.
    """

    # Timestamp used as the unique model identifier in filenames and logs.
    model_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    encoder = CNN_Encoder(**hparams['encoder'])
    decoder = RNN_Decoder(**hparams['decoder'], vocab_size=vocab_size)

    optimizer = make_optimizer(**hparams['optimizer'])

    # Weight of the doubly-stochastic attention regularizer.
    lambda_reg = hparams['train']['lambda_reg']

    # ckpt = tf.train.Checkpoint(encoder=encoder,
    #                            decoder=decoder,
    #                            optimizer = optimizer)
    # ckpt_manager = tf.train.CheckpointManager(ckpt, CHECKPOINT_PATH, max_to_keep=5)


    start_epoch = 0
    # if ckpt_manager.latest_checkpoint:
    #   start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    #   # restoring the latest checkpoint in checkpoint_path
    #   ckpt.restore(ckpt_manager.latest_checkpoint)

    @tf.function
    def train_step(img_tensor, target):
        # One teacher-forced pass over a caption batch; returns the total
        # loss and a dict of per-component losses (normalized by length).
        loss = 0
        losses = {}

        batch_size, caption_length = target.shape

        # initializing the hidden state for each batch
        # because the captions are not related from image to image
        hidden = decoder.reset_state(batch_size = batch_size)

        # Decoder starts from the '<start>' token for every sequence.
        dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * batch_size, 1)
        # attention_plot = tf.Variable(tf.zeros((batch_size,
        #                                      caption_length,
        #                                      attention_features_shape)))


        with tf.GradientTape() as tape:

            features = encoder(img_tensor, training = True)
            # Running sum of attention weights for the regularization term.
            attention_sum = tf.zeros((batch_size, attention_features_shape, 1))

            for i in range(1, caption_length):
                # passing the features through the decoder
                predictions, hidden, attention_weights = decoder((dec_input, features, hidden), training = True)
                attention_sum += attention_weights

                # loss += loss_function(target[:, i], predictions)

                # using teacher forcing
                dec_input = tf.expand_dims(target[:, i], 1)

            # NOTE(review): always 0 — see docstring.
            losses['cross_entropy'] = loss/caption_length

            # attention regularization loss (encourages attention to sum to 1
            # across time at every spatial location)
            loss_attn_reg = lambda_reg * tf.reduce_sum((1 - attention_sum)**2)
            losses['attention_reg'] = loss_attn_reg/caption_length
            loss += loss_attn_reg

            # Weight decay losses
            loss_weight_decay = tf.add_n(encoder.losses) + tf.add_n(decoder.losses)
            losses['weight_decay'] = loss_weight_decay/caption_length
            loss += loss_weight_decay



        losses['total'] = loss/ caption_length

        trainable_variables = encoder.trainable_variables + decoder.trainable_variables

        gradients = tape.gradient(loss, trainable_variables)

        optimizer.apply_gradients(zip(gradients, trainable_variables))

        return loss, losses

    num_steps = num_examples // BATCH_SIZE

    # Per-epoch training-loss curves and validation metric histories.
    loss_plots = {'cross_entropy':[], 'attention_reg':[], 'weight_decay':[],
                  'total':[]}
    metrics = {'cross-entropy':[], 'bleu-1':[],'bleu-2':[],'bleu-3':[],
               'bleu-4':[], 'meteor':[]}
    epoch_times = []
    val_epoch_times = []

    start = time.time()
    logging.info('Training start for model ' + model_id)
    logging.info('hparams: ' + str(hparams))
    for epoch in range(start_epoch, EPOCHS):
        epoch_start = time.time()
        total_loss = {'cross_entropy':0, 'attention_reg':0, 'weight_decay':0,
                      'total':0}

        for (batch, (img_tensor, target)) in enumerate(dataset_train):
            batch_loss, t_loss = train_step(img_tensor, target)
            for key in total_loss.keys():
                total_loss[key] += float(t_loss[key])

            if batch % 100 == 0:
                logging.info('Epoch {} Batch {} Loss {:.4f}'.format(
                  epoch + 1, batch, batch_loss.numpy() / int(target.shape[1])))

        # storing the epoch end loss value to plot later
        for key in loss_plots.keys():
            loss_plots[key].append(total_loss[key] / num_steps)


        # Evaluate on validation
        val_epoch_start = time.time()
        epoch_scores = validation_scores(dataset_val, (encoder, decoder), tokenizer)
        val_epoch_stop = time.time() - val_epoch_start
        val_epoch_times.append(val_epoch_stop)

        for name, score in epoch_scores.items():
            metrics[name].append(score)

        epoch_stop = time.time() - epoch_start
        epoch_times.append(epoch_stop)

        # if epoch % 1 == 0:
        #   ckpt_manager.save()

        logging.info('Epoch {} Loss {:.6f}'.format(epoch + 1,
                                             total_loss['total']/num_steps))


        logging.info('Time taken for 1 epoch {} sec\n'.format(epoch_stop))

    total_time = time.time() - start
    logging.info('Total training time: {}'.format(total_time))

    # Summary of the run: sizes, timings and learning curves.
    results = { 'id':model_id,
                'losses':loss_plots,
                'epoch_times':epoch_times,
                'total_time':total_time,
                'encoder_params': encoder.count_params(),
                'decoder_params': decoder.count_params(),
                'instances_train': num_examples,
                'instances_valid': num_examples_val,
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'vocabulary': vocab_size,
                'valid_batch_size': VALID_BATCH_SIZE,
                'valid_epoch_times':val_epoch_times,
                'metrics_val': metrics}

    encoder.save_weights(str(models_path) + ('encoder_' + model_id + '.h5'))
    decoder.save_weights(str(models_path) + ('decoder_' + model_id + '.h5'))
    models = (encoder, decoder)

    return results, models
# 예제 #5 (snippet separator from the original source collection)
# 0
def train(hparams, models_path = './'):
    """Train the Keras `Captioner` end-to-end via `fit` with early stopping.

    Args:
        hparams: dict with 'model' and 'optimizer' sub-dicts of keyword
            arguments for `Captioner` and `make_optimizer`.
        models_path: directory (str or Path) where the trained weights are
            written as '<model_id>.h5'.

    Returns:
        results: dict
            dictionary containing model identifier, elapsed_time per epoch,
            learning curve with loss and metrics

    """
    # Local import: models_path may be a plain string (the default is './')
    # and is coerced to a Path below.
    from pathlib import Path

    model_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

    captioner = Captioner(**hparams['model'],
                          vocab_size = vocab_size,
                          tokenizer = tokenizer,
                          batch_size = BATCH_SIZE,
                          caption_length = train_max_length,
                          valid_batch_size = VALID_BATCH_SIZE,
                          num_examples_val = num_examples_val)

    optimizer = make_optimizer(**hparams['optimizer'])

    metrics = [BLEUMetric(n_gram=1, name = 'bleu-1'),
               BLEUMetric(n_gram=2, name = 'bleu-2'),
               BLEUMetric(n_gram=3, name = 'bleu-3'),
               BLEUMetric(n_gram=4, name = 'bleu-4'),
               METEORMetric(name = 'meteor')]

    # run_eagerly is required for the python-side BLEU/METEOR metrics.
    captioner.compile(optimizer, loss_fn = padded_cross_entropy,
                      metrics = metrics, run_eagerly = True)

    logger_cb = LoggerCallback()
    # Stop when val BLEU-4 plateaus; keep the best weights seen.
    early_stopping_cb = EarlyStopping(monitor = 'val_bleu-4', patience = 10,
                                      mode = 'max',
                                      restore_best_weights = True)

    logging.info('Training start for model ' + model_id)
    logging.info('hparams: ' + str(hparams))

    history = captioner.fit(dataset_train, epochs=EPOCHS,
                            validation_data = dataset_val,
                            validation_steps = num_examples_val//VALID_BATCH_SIZE,
                            callbacks=[logger_cb, early_stopping_cb])

    # Split history into training losses ('loss', ...) and validation
    # metrics ('val_...', with the 'val_' prefix stripped).
    losses = {key:value for key, value in history.history.items() if 'val' not in key}
    metrics = {key[4:]:value for key, value in history.history.items() if 'val' in key}

    results = { 'id': model_id,
                'losses': losses,
                'epoch_times': logger_cb.epoch_times,
                'total_time': logger_cb.total_time,
                'params': captioner.count_params(),
                'instances_train': num_examples,
                'instances_valid': num_examples_val,
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'vocabulary': vocab_size,
                'valid_batch_size': VALID_BATCH_SIZE,
                'valid_epoch_times': logger_cb.validation_times,
                'metrics': metrics
                }

    # BUGFIX: models_path defaults to the plain string './', for which the
    # `/` path operator raises TypeError; coerce to Path first.
    captioner.save_weights(str(Path(models_path) / (model_id + '.h5')))

    return results