import argparse
import json
import os

import chainer
import numpy
from chainer import training
from chainer.training import extensions

# Project-local modules (their import lines were not shown in this snippet
# and are assumed here):
import bilm_nets
import chain_utils
import class_nets
import nets
import text_datasets
# The bare helpers convert_seq, UnkDropout, MicroEvaluator and Outer are
# likewise project-local; their defining modules are not shown.


def construct(args):
    vocab = None

    # Load a dataset
    if args.dataset == 'dbpedia':
        train, test, vocab = text_datasets.get_dbpedia(vocab=vocab)
    elif args.dataset.startswith('imdb.'):
        train, test, vocab = text_datasets.get_imdb(
            fine_grained=args.dataset.endswith('.fine'), vocab=vocab)
    elif args.dataset in [
            'TREC', 'stsa.binary', 'stsa.fine', 'custrev', 'mpqa',
            'rt-polarity', 'subj'
    ]:
        train, test, vocab = text_datasets.get_other_text_dataset(
            args.dataset, vocab=vocab)
    else:
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    if not os.path.exists('vocabs'):
        os.mkdir('vocabs')
    vocab_path = os.path.join('vocabs', '{}.vocab.json'.format(args.dataset))
    with open(vocab_path, 'w') as f:
        json.dump(vocab, f)
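
# Hypothetical usage sketch (not from the original file): pre-build and cache
# the vocabulary for one dataset, with an argparse-style namespace standing in
# for real command-line arguments:
#
#     from argparse import Namespace
#     construct(Namespace(dataset='TREC'))
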
def train(dir="datasets", print_log=False):
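    # NOTE: `args` is read from module scope here, and this function is
    # written to be called repeatedly (see the driver sketch after it).
    # Assigning to chainer.CHAINER_SEED only sets a module attribute; the
    # reproducibility below actually comes from seeding numpy (and, on GPU,
    # model.xp.random).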
    chainer.CHAINER_SEED = args.seed
    numpy.random.seed(args.seed)

    vocab = None

    # Load a dataset
    if args.dataset == 'dbpedia':
        train, test, vocab = text_datasets.get_dbpedia(vocab=vocab)
    elif args.dataset.startswith('imdb.'):
        train, test, vocab = text_datasets.get_imdb(
            fine_grained=args.dataset.endswith('.fine'), vocab=vocab)
    elif args.dataset in [
            'TREC', 'stsa.binary', 'stsa.fine', 'custrev', 'mpqa',
            'rt-polarity', 'subj', 'toxic'
    ]:
        # read_text_dataset also returns a held-out real test set; `test`
        # serves as the validation split.  NOTE: only this branch defines
        # `real_test`, which the test-set evaluator below requires.
        train, test, real_test, vocab = text_datasets.read_text_dataset(
            args.dataset, vocab=vocab, dir=dir)
    else:
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    print('# train data: {}'.format(len(train)))
    print('# test  data: {}'.format(len(test)))
    print('# vocab: {}'.format(len(vocab)))
    n_class = len(set([int(d[1]) for d in train]))
    print('# class: {}'.format(n_class))

    chainer.CHAINER_SEED = args.seed
    numpy.random.seed(args.seed)
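    # UnkDropout (project-local) presumably replaces each training token with
    # <unk> at the given rate (here 1%) as a mild regularizer; a sketch of the
    # presumed behavior follows the iterator setup below.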
    train = UnkDropout(train, vocab['<unk>'], 0.01)
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test,
                                                 args.batchsize,
                                                 repeat=False,
                                                 shuffle=False)
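
    # A minimal sketch of the presumed UnkDropout behavior (hypothetical; the
    # real class lives elsewhere in this repository):
    #
    #     class UnkDropout(chainer.dataset.DatasetMixin):
    #         def __init__(self, dataset, unk_id, ratio):
    #             self.dataset = dataset
    #             self.unk_id = unk_id
    #             self.ratio = ratio
    #
    #         def __len__(self):
    #             return len(self.dataset)
    #
    #         def get_example(self, i):
    #             x, y = self.dataset[i]
    #             mask = numpy.random.rand(len(x)) < self.ratio
    #             return numpy.where(mask, self.unk_id, x).astype(x.dtype), y
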

    # Setup a model
    chainer.CHAINER_SEED = args.seed
    numpy.random.seed(args.seed)
    if args.model == 'rnn':
        Encoder = class_nets.RNNEncoder
    elif args.model == 'cnn':
        Encoder = class_nets.CNNEncoder
    elif args.model == 'bow':
        Encoder = class_nets.BOWMLPEncoder
    else:
        raise ValueError('unknown model: {}'.format(args.model))
    encoder = Encoder(n_layers=args.layer,
                      n_vocab=len(vocab),
                      n_units=args.unit,
                      dropout=args.dropout)
    model = class_nets.TextClassifier(encoder, n_class)

    if args.bilm:
        bilm = bilm_nets.BiLanguageModel(len(vocab), args.bilm_unit,
                                         args.bilm_layer, args.bilm_dropout)
        n_labels = len(set([int(v[1]) for v in test]))
        print('# labels =', n_labels)
        if not args.no_label:
            print('add label')
            bilm.add_label_condition_nets(n_labels, args.bilm_unit)
        else:
            print('not using label')
        chainer.serializers.load_npz(args.bilm, bilm)
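        # Swap the encoder's plain embedding for a biLM-backed predictive
        # embedding, carrying over the already-initialized embedding weights
        # as initialW.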
        with model.encoder.init_scope():
            initialW = numpy.array(model.encoder.embed.W.data)
            del model.encoder.embed
            model.encoder.embed = bilm_nets.PredictiveEmbed(len(vocab),
                                                            args.unit,
                                                            bilm,
                                                            args.dropout,
                                                            initialW=initialW)
            model.encoder.use_predict_embed = True

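            # The mode/temperature/Gumbel options below control how the biLM
            # proposes substitute words; augment_ratio is presumably the
            # fraction of tokens replaced per sentence, and ignore_unk keeps
            # <unk> out of the substitutions.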
            model.encoder.embed.setup(mode=args.bilm_mode,
                                      temp=args.bilm_temp,
                                      word_lower_bound=0.,
                                      gold_lower_bound=0.,
                                      gumbel=args.bilm_gumbel,
                                      residual=args.bilm_residual,
                                      wordwise=args.bilm_wordwise,
                                      add_original=args.bilm_add_original,
                                      augment_ratio=args.bilm_ratio,
                                      ignore_unk=vocab['<unk>'])

    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU
        model.xp.random.seed(args.seed)
    chainer.CHAINER_SEED = args.seed
    numpy.random.seed(args.seed)

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    optimizer.setup(model)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       converter=convert_seq,
                                       device=args.gpu)
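
    # convert_seq (project-local) is the batch converter; presumably, like the
    # converter in Chainer's text-classification example, it splits each
    # (token-ids, label) pair and moves both parts to the target device.  A
    # simplified sketch:
    #
    #     def convert_seq(batch, device=None):
    #         to = lambda a: chainer.dataset.to_device(device, a)
    #         return {'xs': [to(x) for x, _ in batch],
    #                 'ys': to(numpy.array([y for _, y in batch],
    #                                      dtype=numpy.int32))}
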

    from triggers import FailMaxValueTrigger
    stop_trigger = FailMaxValueTrigger(key='validation/main/accuracy',
                                       trigger=(1, 'epoch'),
                                       n_times=args.stop_epoch,
                                       max_trigger=args.epoch)
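
    # Hypothetical sketch of FailMaxValueTrigger's presumed semantics (the
    # real implementation lives in this repository's triggers.py): stop once
    # the watched value has failed to beat its running maximum n_times checks
    # in a row, or once max_trigger epochs have elapsed.
    #
    #     from chainer.training import util
    #
    #     class FailMaxValueTrigger(object):
    #         def __init__(self, key, trigger, n_times, max_trigger):
    #             self.key = key
    #             self.interval = util.get_trigger(trigger)
    #             self.n_times = n_times
    #             self.max_trigger = max_trigger
    #             self.best = None
    #             self.fails = 0
    #
    #         def __call__(self, trainer):
    #             if trainer.updater.epoch >= self.max_trigger:
    #                 return True              # hard cap on training length
    #             if not self.interval(trainer):
    #                 return False             # only check once per interval
    #             value = trainer.observation.get(self.key)
    #             if value is None:
    #                 return False
    #             if self.best is None or float(value) > self.best:
    #                 self.best, self.fails = float(value), 0
    #             else:
    #                 self.fails += 1
    #             return self.fails >= self.n_times
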
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    # Evaluate the model with the test dataset for each epoch
    # VALIDATION SET
    trainer.extend(
        MicroEvaluator(test_iter,
                       model,
                       converter=convert_seq,
                       device=args.gpu))

    # TEST SET: also evaluate on the held-out real test set every epoch; only
    # the validation keys above drive stopping and snapshotting.
    real_test_iter = chainer.iterators.SerialIterator(real_test,
                                                      args.batchsize,
                                                      repeat=False,
                                                      shuffle=False)
    eval_on_real_test = MicroEvaluator(real_test_iter,
                                       model,
                                       converter=convert_seq,
                                       device=args.gpu)
    eval_on_real_test.default_name = 'test'
    trainer.extend(eval_on_real_test)

    # Take a best snapshot
    record_trigger = training.triggers.MaxValueTrigger(
        'validation/main/accuracy', (1, 'epoch'))
    if args.save_model:
        trainer.extend(extensions.snapshot_object(model, 'best_model.npz'),
                       trigger=record_trigger)

    # Write a log of evaluation statistics for each epoch
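    # Outer (project-local) is presumably a stream-like collector that stores
    # the values PrintReport writes into it (out=out below), so the final
    # validation/test accuracies can be read back after training.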
    out = Outer()
    trainer.extend(extensions.LogReport())
    if print_log:
        trainer.extend(extensions.PrintReport([
            'epoch',
            'main/loss',
            'validation/main/loss',
            'main/accuracy',
            'validation/main/accuracy',
            'test/main/loss',
            'test/main/accuracy',
            'elapsed_time',
        ]),
                       trigger=record_trigger)
    else:
        trainer.extend(extensions.PrintReport(
            ['main/accuracy', 'validation/main/accuracy',
             'test/main/accuracy'],
            out=out),
            trigger=record_trigger)

    # Print a progress bar to stdout
    #trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()

    # Free all unused memory blocks cached in CuPy's memory pool so that
    # repeated calls to train() do not exhaust GPU memory.
    if args.gpu >= 0:
        import cupy
        mempool = cupy.get_default_memory_pool()
        mempool.free_all_blocks()
    # print('val_acc: {}, test_acc: {}'.format(out[-2], out[-1]))
    return float(out[-1])  # test accuracy at the best-validation epoch
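
# Hypothetical driver sketch (not from the original file): train() returns a
# single accuracy and releases the CuPy memory pool, so it lends itself to
# repeated calls, e.g. averaging over several seeds:
#
#     scores = []
#     for seed in range(5):
#         args.seed = seed
#         scores.append(train())
#     print('mean test accuracy:', sum(scores) / len(scores))
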
def main():
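    # Pretraining script for the bidirectional language model; the snapshot it
    # writes (best_model.npz) is presumably what train() loads via args.bilm.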
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=32,
                        help='Number of examples in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=5,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gradclip',
                        '-c',
                        type=float,
                        default=10,
                        help='Gradient norm threshold to clip')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--unit',
                        '-u',
                        type=int,
                        default=1024,
                        help='Number of LSTM units in each layer')
    parser.add_argument('--layer', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.5)

    parser.add_argument('--vocab', required=True)
    parser.add_argument('--train-path', '--train')
    parser.add_argument('--valid-path', '--valid')

    parser.add_argument('--resume')

    parser.add_argument('--labeled-dataset',
                        '-ldata',
                        default=None,
                        choices=[
                            'dbpedia', 'imdb.binary', 'imdb.fine', 'TREC',
                            'stsa.binary', 'stsa.fine', 'custrev', 'mpqa',
                            'rt-polarity', 'subj'
                        ],
                        help='Name of dataset.')
    parser.add_argument('--no-label', action='store_true')

    parser.add_argument('--validation', action='store_true')

    args = parser.parse_args()
    print(json.dumps(args.__dict__, indent=2))

    with open(args.vocab) as f:
        vocab = json.load(f)
    if args.labeled_dataset:
        if args.labeled_dataset == 'dbpedia':
            train, valid, _ = text_datasets.get_dbpedia(vocab=vocab)
        elif args.labeled_dataset.startswith('imdb.'):
            train, valid, _ = text_datasets.get_imdb(
                fine_grained=args.labeled_dataset.endswith('.fine'),
                vocab=vocab)
        elif args.labeled_dataset in [
                'TREC', 'stsa.binary', 'stsa.fine', 'custrev', 'mpqa',
                'rt-polarity', 'subj'
        ]:
            train, valid, _ = text_datasets.get_other_text_dataset(
                args.labeled_dataset, vocab=vocab)

        if args.validation:
            train, valid = \
                chainer.datasets.get_cross_validation_datasets_random(
                    train, 10, seed=777)[0]
        else:
            print('WARNING: tuning against the test split. '
                  'Pass --validation to use a held-out validation split.')
    else:
        train = chain_utils.SequenceChainDataset(args.train_path,
                                                 vocab,
                                                 chain_length=1)
        valid = chain_utils.SequenceChainDataset(args.valid_path,
                                                 vocab,
                                                 chain_length=1)

    print('#train =', len(train))
    print('#valid =', len(valid))
    print('#vocab =', len(vocab))

    # Create the dataset iterators
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  args.batchsize,
                                                  repeat=False,
                                                  shuffle=False)

    # Prepare a bidirectional RNN language model
    model = nets.BiLanguageModel(len(vocab), args.unit, args.layer,
                                 args.dropout)

    if args.resume:
        print('load {}'.format(args.resume))
        chainer.serializers.load_npz(args.resume, model)

    if args.labeled_dataset and not args.no_label:
        n_labels = len(set([int(v[1]) for v in valid]))
        print('# labels =', n_labels)
        model.add_label_condition_nets(n_labels, args.unit)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()

    # Set up an optimizer
    optimizer = chainer.optimizers.Adam(alpha=args.lr)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip))

    iter_per_epoch = len(train) // args.batchsize
    print('{} iters per epoch'.format(iter_per_epoch))
    if iter_per_epoch >= 10000:
        log_trigger = (iter_per_epoch // 100, 'iteration')
        eval_trigger = (log_trigger[0] * 50, 'iteration')  # every half epoch
    else:
        log_trigger = (iter_per_epoch // 2, 'iteration')
        eval_trigger = (log_trigger[0] * 2, 'iteration')  # every epoch
    print('log and eval are scheduled at every {} and {}'.format(
        log_trigger, eval_trigger))
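
    # Worked example (hypothetical numbers): with 25,000 training examples and
    # batchsize 32, iter_per_epoch = 781 < 10000, so log_trigger becomes
    # (390, 'iteration') (about twice per epoch) and eval_trigger becomes
    # (780, 'iteration') (about once per epoch).
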

    if args.labeled_dataset:
        updater = training.StandardUpdater(
            train_iter,
            optimizer,
            converter=convert_seq,
            device=args.gpu,
            loss_func=model.calculate_loss_with_labels)

        trainer = training.Trainer(updater, (args.epoch, 'epoch'),
                                   out=args.out)
        trainer.extend(extensions.Evaluator(
            valid_iter,
            model,
            converter=convert_seq,
            device=args.gpu,
            eval_func=model.calculate_loss_with_labels),
                       trigger=eval_trigger)
    else:
        updater = training.StandardUpdater(
            train_iter,
            optimizer,
            converter=chain_utils.convert_sequence_chain,
            device=args.gpu,
            loss_func=model.calculate_loss)

        trainer = training.Trainer(updater, (args.epoch, 'epoch'),
                                   out=args.out)
        trainer.extend(extensions.Evaluator(
            valid_iter,
            model,
            converter=chain_utils.convert_sequence_chain,
            device=args.gpu,
            eval_func=model.calculate_loss),
                       trigger=eval_trigger)

    record_trigger = training.triggers.MinValueTrigger('validation/main/perp',
                                                       trigger=eval_trigger)
    trainer.extend(extensions.snapshot_object(model, 'best_model.npz'),
                   trigger=record_trigger)
    trainer.extend(extensions.LogReport(trigger=log_trigger))
    keys = [
        'epoch', 'iteration', 'main/perp', 'validation/main/perp',
        'elapsed_time'
    ]
    trainer.extend(extensions.PrintReport(keys), trigger=log_trigger)
    trainer.extend(extensions.ProgressBar(update_interval=50))

    print('iter/epoch', iter_per_epoch)
    print('Training start')
    trainer.run()
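

# Standard entry-point guard so the pretraining script can be run directly.
if __name__ == '__main__':
    main()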