def observe_hyperparam(name, trigger):
    """Build an extension that reports a hyperparameter of the 'main' optimizer.

    The value is read with ordinary attribute lookup (``getattr``) rather
    than by indexing the optimizer's instance ``__dict__``, so it also works
    when the hyperparameter is exposed through a property or descriptor
    instead of being stored directly on the instance.

    :param str name: attribute name on the optimizer; also used as the key
        under which the value is reported.
    :param trigger: trigger passed through to ``extensions.observe_value``.
    :return: the extension created by ``extensions.observe_value``.

    >>> trainer.extend(observe_hyperparam('alpha', (1, 'epoch')))
    """
    def observer(trainer):
        # Look the optimizer up on every call so the current value is reported.
        return getattr(trainer.updater.get_optimizer('main'), name)

    return extensions.observe_value(name, observer, trigger=trigger)
Beispiel #2
0
def main(n, nb_epochs):
    """Train a ``ResNet(n)`` classifier on CIFAR-10 for ``nb_epochs`` epochs.

    The learning rate is multiplied by 0.1 when the updater's epoch equals
    ``int(nb_epochs * 0.5)`` or ``int(nb_epochs * 0.75)``.  Results are
    written to the directory ``chainer_n{n}``.

    :param n: argument forwarded to ``ResNet``.
    :param nb_epochs: number of epochs after which training stops.
    """
    # Device selector: -1 means CPU, 0 means GPU.
    gpu_id = 0

    raw_train, raw_test = chainer.datasets.get_cifar10()
    train_set = Preprocess(raw_train)
    test_set = Preprocess(raw_test)
    train_it = chainer.iterators.SerialIterator(train_set, 128)
    test_it = chainer.iterators.SerialIterator(
        test_set, 100, repeat=False, shuffle=False)

    classifier = chainer.links.Classifier(ResNet(n))

    sgd = chainer.optimizers.MomentumSGD(lr=0.01, momentum=0.9)
    sgd.setup(classifier)
    sgd.add_hook(chainer.optimizer.WeightDecay(0.0005))

    updater = training.StandardUpdater(train_it, sgd, device=gpu_id)
    trainer = training.Trainer(
        updater, (nb_epochs, "epoch"), out=f"chainer_n{n}")

    every_epoch = (1, "epoch")

    # Learning-rate schedule: epochs at which the rate is cut by 10x.
    decay_epochs = (int(nb_epochs * 0.5), int(nb_epochs * 0.75))

    def lr_shift():
        """Apply the step decay if due and return the current rate."""
        if updater.epoch in decay_epochs:
            sgd.lr *= 0.1
        return sgd.lr

    def report_lr(_trainer):
        """Observer wrapper: the trainer argument is not needed."""
        return lr_shift()

    report_columns = [
        'elapsed_time',
        'epoch',
        'iteration',
        'main/loss',
        'validation/main/loss',
        'main/accuracy',
        'validation/main/accuracy',
        'lr',
    ]

    # Registration order is kept the same as the order of the calls below.
    trainer.extend(
        extensions.Evaluator(test_it, classifier, device=gpu_id),
        trigger=every_epoch)
    trainer.extend(
        extensions.observe_value("lr", report_lr),
        trigger=(1, "epoch"))
    trainer.extend(extensions.LogReport(trigger=every_epoch))
    trainer.extend(
        extensions.PrintReport(report_columns),
        trigger=every_epoch)
    trainer.extend(extensions.ProgressBar(update_interval=50))

    trainer.run()
Beispiel #3
0
def main():
    """Parse command-line options and train a ``net.Transformer`` model.

    The function reads the source/target training and validation corpora
    from ``--input``, builds vocabularies, constructs the model and an Adam
    optimizer, wires up a Chainer ``Trainer`` with evaluation, learning-rate
    scheduling, best-model snapshotting, sample translation, optional BLEU
    computation and logging extensions, and finally runs training.
    """
    parser = argparse.ArgumentParser(
        description='Chainer example: convolutional seq2seq')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=48,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch',
                        '-e',
                        type=int,
                        default=100,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu',
                        '-g',
                        type=int,
                        default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--unit',
                        '-u',
                        type=int,
                        default=512,
                        help='Number of units')
    parser.add_argument('--layer',
                        '-l',
                        type=int,
                        default=6,
                        help='Number of layers')
    parser.add_argument('--head',
                        type=int,
                        default=8,
                        help='Number of heads in attention mechanism')
    parser.add_argument('--dropout',
                        '-d',
                        type=float,
                        default=0.1,
                        help='Dropout rate')
    parser.add_argument('--input',
                        '-i',
                        type=str,
                        default='./',
                        help='Input directory')
    parser.add_argument('--source',
                        '-s',
                        type=str,
                        default='europarl-v7.fr-en.en',
                        help='Filename of train data for source language')
    parser.add_argument('--target',
                        '-t',
                        type=str,
                        default='europarl-v7.fr-en.fr',
                        help='Filename of train data for target language')
    parser.add_argument('--source-valid',
                        '-svalid',
                        type=str,
                        default='dev/newstest2013.en',
                        help='Filename of validation data for source language')
    parser.add_argument('--target-valid',
                        '-tvalid',
                        type=str,
                        default='dev/newstest2013.fr',
                        help='Filename of validation data for target language')
    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')
    parser.add_argument('--source-vocab',
                        type=int,
                        default=40000,
                        help='Vocabulary size of source language')
    parser.add_argument('--target-vocab',
                        type=int,
                        default=40000,
                        help='Vocabulary size of target language')
    parser.add_argument('--no-bleu',
                        '-no-bleu',
                        action='store_true',
                        help='Skip BLEU calculation')
    parser.add_argument('--use-label-smoothing',
                        action='store_true',
                        help='Use label smoothing for cross entropy')
    parser.add_argument('--embed-position',
                        action='store_true',
                        help='Use position embedding rather than sinusoid')
    parser.add_argument('--use-fixed-lr',
                        action='store_true',
                        help='Use fixed learning rate rather than the ' +
                        'annealing proposed in the paper')
    args = parser.parse_args()
    # Echo the effective configuration.
    print(json.dumps(args.__dict__, indent=4))

    # Load the training corpora and build the vocabularies.
    # Indices 0, 1 and 2 are reserved for '<eos>', '<unk>' and '<bos>'.
    en_path = os.path.join(args.input, args.source)
    source_vocab = ['<eos>', '<unk>', '<bos>'] + \
        preprocess.count_words(en_path, args.source_vocab)
    source_data = preprocess.make_dataset(en_path, source_vocab)
    fr_path = os.path.join(args.input, args.target)
    target_vocab = ['<eos>', '<unk>', '<bos>'] + \
        preprocess.count_words(fr_path, args.target_vocab)
    target_data = preprocess.make_dataset(fr_path, target_vocab)
    assert len(source_data) == len(target_data)
    print('Original training data size: %d' % len(source_data))
    # Keep only pairs where both sides are non-empty and shorter than 50.
    train_data = [(s, t) for s, t in six.moves.zip(source_data, target_data)
                  if 0 < len(s) < 50 and 0 < len(t) < 50]
    print('Filtered training data size: %d' % len(train_data))

    # Load the validation corpora with the training vocabularies;
    # only pairs with an empty side are dropped here.
    en_path = os.path.join(args.input, args.source_valid)
    source_data = preprocess.make_dataset(en_path, source_vocab)
    fr_path = os.path.join(args.input, args.target_valid)
    target_data = preprocess.make_dataset(fr_path, target_vocab)
    assert len(source_data) == len(target_data)
    test_data = [(s, t) for s, t in six.moves.zip(source_data, target_data)
                 if 0 < len(s) and 0 < len(t)]

    # word -> index lookup tables
    source_ids = {word: index for index, word in enumerate(source_vocab)}
    target_ids = {word: index for index, word in enumerate(target_vocab)}

    # index -> word lookup tables (inverse of the above)
    target_words = {i: w for w, i in target_ids.items()}
    source_words = {i: w for w, i in source_ids.items()}

    # Define Model
    model = net.Transformer(args.layer,
                            min(len(source_ids), len(source_words)),
                            min(len(target_ids), len(target_words)),
                            args.unit,
                            h=args.head,
                            dropout=args.dropout,
                            max_length=500,
                            use_label_smoothing=args.use_label_smoothing,
                            embed_position=args.embed_position)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        model.to_gpu(args.gpu)

    # Setup Optimizer
    optimizer = chainer.optimizers.Adam(alpha=5e-5,
                                        beta1=0.9,
                                        beta2=0.98,
                                        eps=1e-9)
    optimizer.setup(model)

    # Setup Trainer
    train_iter = chainer.iterators.SerialIterator(train_data, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test_data,
                                                 args.batchsize,
                                                 repeat=False,
                                                 shuffle=False)
    iter_per_epoch = len(train_data) // args.batchsize
    print('Number of iter/epoch =', iter_per_epoch)

    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       converter=seq2seq_pad_concat_convert,
                                       device=args.gpu)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # If you want to change a logging interval, change this number
    # NOTE(review): this evaluates to 0 when iter_per_epoch < 2, which also
    # makes the modulo in floor_step divide by zero -- confirm that such small
    # datasets are out of scope.
    log_trigger = (min(200, iter_per_epoch // 2), 'iteration')

    def floor_step(trigger):
        # Round the interval down to a multiple of the logging interval;
        # fall back to the original interval if that would give <= 0.
        floored = trigger[0] - trigger[0] % log_trigger[0]
        if floored <= 0:
            floored = trigger[0]
        return (floored, trigger[1])

    # Validation every half epoch
    eval_trigger = floor_step((iter_per_epoch // 2, 'iteration'))
    # Minimum-value trigger on validation perplexity, built on eval_trigger.
    record_trigger = training.triggers.MinValueTrigger('val/main/perp',
                                                       eval_trigger)

    evaluator = extensions.Evaluator(test_iter,
                                     model,
                                     converter=seq2seq_pad_concat_convert,
                                     device=args.gpu)
    # Validation metrics are reported under the 'val/' prefix.
    evaluator.default_name = 'val'
    trainer.extend(evaluator, trigger=eval_trigger)

    # Use Vaswani's learning-rate rule (Eq. 3 in the paper).
    # However, the hyperparameters in the paper seem to work well
    # only with a large batch size.
    # If you run on a popular setup (e.g. size=48 on 1 GPU),
    # you may have to change the hyperparameters.
    # I scaled the learning rate by 0.5 consistently.
    # ("scale" is always multiplied to the learning rate.)

    # If you use a shallow network (<= 2 layers),
    # you may not have to change it from the paper setting.
    if not args.use_fixed_lr:
        trainer.extend(
            # VaswaniRule('alpha', d=args.unit, warmup_steps=4000, scale=1.),
            # VaswaniRule('alpha', d=args.unit, warmup_steps=32000, scale=1.),
            # VaswaniRule('alpha', d=args.unit, warmup_steps=4000, scale=0.5),
            # VaswaniRule('alpha', d=args.unit, warmup_steps=16000, scale=1.),
            VaswaniRule('alpha', d=args.unit, warmup_steps=64000, scale=1.),
            trigger=(1, 'iteration'))
    # Report the optimizer's current alpha every iteration so it can be logged.
    observe_alpha = extensions.observe_value(
        'alpha', lambda trainer: trainer.updater.get_optimizer('main').alpha)
    trainer.extend(observe_alpha, trigger=(1, 'iteration'))

    # Only if a model gets best validation score,
    # save (overwrite) the model
    trainer.extend(extensions.snapshot_object(model, 'best_model.npz'),
                   trigger=record_trigger)

    def translate_one(source, target):
        # Translate one source sentence with beam search (beam=5) and print
        # the source, the model output and the expected target.
        words = preprocess.split_sentence(source)
        print('# source : ' + ' '.join(words))
        # Words missing from the vocabulary map to index 1 ('<unk>').
        x = model.xp.array([source_ids.get(w, 1) for w in words], 'i')
        ys = model.translate([x], beam=5)[0]
        words = [target_words[y] for y in ys]
        print('#  result : ' + ' '.join(words))
        print('#  expect : ' + target)

    # NOTE(review): an explicit trigger is also passed to trainer.extend
    # below; presumably that one takes precedence over this default.
    @chainer.training.make_extension(trigger=(200, 'iteration'))
    def translate(trainer):
        # Two fixed sample sentences plus one random validation pair.
        translate_one('Who are we ?', 'Qui sommes-nous?')
        translate_one(
            'And it often costs over a hundred dollars ' +
            'to obtain the required identity card .',
            'Or, il en coûte souvent plus de cent dollars ' +
            'pour obtenir la carte d\'identité requise.')

        source, target = test_data[numpy.random.choice(len(test_data))]
        source = ' '.join([source_words[i] for i in source])
        target = ' '.join([target_words[i] for i in target])
        translate_one(source, target)

    # Generation test
    trainer.extend(translate, trigger=(min(200, iter_per_epoch), 'iteration'))

    # Calculate BLEU every half epoch
    if not args.no_bleu:
        trainer.extend(CalculateBleu(model,
                                     test_data,
                                     'val/main/bleu',
                                     device=args.gpu,
                                     batch=args.batchsize // 4),
                       trigger=floor_step((iter_per_epoch // 2, 'iteration')))

    # Log
    trainer.extend(extensions.LogReport(trigger=log_trigger),
                   trigger=log_trigger)
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'main/loss', 'val/main/loss', 'main/perp',
        'val/main/perp', 'main/acc', 'val/main/acc', 'val/main/bleu', 'alpha',
        'elapsed_time'
    ]),
                   trigger=log_trigger)

    print('start training')
    trainer.run()
Beispiel #4
0
def train(args):
    """Train with the given args

    Reads the input/output dimensions from the validation json, builds the
    E2E model (optionally attaching a pre-trained RNNLM), writes the model
    configuration to ``args.outdir``, sets up the optimizer, data iterators
    and a Chainer ``Trainer`` with evaluation, plotting, snapshot, epsilon
    decay and reporting extensions, and runs the training.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    model = E2E(idim, odim, args)
    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(
                len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        # Load the saved weights INTO the freshly built RNNLM.  The previous
        # call, torch.load(args.rnnlm, rnnlm), passed the model as
        # ``map_location`` and discarded the result, so nothing was loaded.
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim, odim, vars(args)), indent=4, sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # keep a handle on the reporter before the model is possibly wrapped
    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' % (
            args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps,
            weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)

    # FIXME: TOO DIRTY HACK
    # Attach Chainer-style ``target`` / ``serialize`` attributes to the
    # torch optimizer.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                preprocess_conf=args.preprocess_conf)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1)
    # hack to make the batch size argument 1:
    # the actual batch size is included in a list
    if args.n_iter_processes > 0:
        train_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(train, converter.transform),
            batch_size=1, n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False,
            n_processes=args.n_iter_processes, n_prefetch=8, maxtasksperchild=20)
    else:
        train_iter = chainer.iterators.SerialIterator(
            TransformDataset(train, converter.transform),
            batch_size=1)
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, converter.transform),
            batch_size=1, repeat=False, shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, converter, device, args.ngpu)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]), reverse=True)
        # DataParallel wraps the real model in ``.module``
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
        else:
            att_vis_fn = model.calculate_all_attentions
        att_reporter = PlotAttentionReport(
            att_vis_fn, data, args.outdir + "/att_ws",
            converter=converter, device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))

    # Save best models
    trainer.extend(extensions.snapshot_object(model, 'model.loss.best', savefun=torch_save),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # Compare strings by value; ``is not`` on a literal tests object identity.
    if mtl_mode != 'ctc':
        trainer.extend(extensions.snapshot_object(model, 'model.acc.best', savefun=torch_save),
                       trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics every REPORT_INTERVAL iterations
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'elapsed_time']
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #5
0
def train(args):
    """Run training.

    Seeds torch, configures debug/determinism settings from
    ``args.debugmode``, reads dimensions from the validation json, builds the
    E2E model wrapped in ``Loss``, writes a pickled model config to
    ``args.outdir``, sets up the optimizer, iterators and a Chainer
    ``Trainer`` with evaluation, plotting, snapshot, epsilon-decay and
    reporting extensions, and runs the training.

    :param Namespace args: The program arguments
    """
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('torch type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    # TODO(nelson) remove in future
    if 'input' not in valid_json[utts[0]]:
        # The two literals are concatenated, so the space must be explicit.
        logging.error("input file format (json) is modified, please redo "
                      "stage 2: Dictionary and Json Data Preparation")
        sys.exit(1)
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = Loss(e2e, args.mtlalpha)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.conf'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # TODO(watanabe) use others than pickle, possibly json, and save as a text
        pickle.dump((idim, odim, args), f)
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    # keep a handle on the reporter before the model is possibly wrapped
    reporter = model.reporter
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = range(ngpu)
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
    elif ngpu > 1:
        gpu_id = range(ngpu)
        logging.info('gpu id: ' + str(gpu_id))
        model = DataParallel(model, device_ids=gpu_id)
        model.cuda()
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu
    else:
        gpu_id = [-1]

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())

    # FIXME: TOO DIRTY HACK
    # Attach Chainer-style ``target`` / ``serialize`` attributes to the
    # torch optimizer.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    valid = make_batchset(valid_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    # hack to make the batch size argument 1:
    # the actual batch size is included in a list
    train_iter = chainer.iterators.SerialIterator(train, 1)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  1,
                                                  repeat=False,
                                                  shuffle=False)

    # converter choice
    converter = converter_kaldi
    if args.input_tensor:
        converter = converter_tensor

    # Set up a trainer
    updater = PytorchSeqUpdaterKaldi(model,
                                     args.grad_clip,
                                     train_iter,
                                     optimizer,
                                     converter=converter,
                                     device=gpu_id)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)
        # DataParallel wraps the real model in ``.module``
        if ngpu > 1:
            model.module.load_state_dict(
                torch.load(args.outdir + '/model.acc.best'))
        else:
            model.load_state_dict(torch.load(args.outdir + '/model.acc.best'))
        model = trainer.updater.model

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        PytorchSeqEvaluaterKaldi(model,
                                 valid_iter,
                                 reporter,
                                 converter=converter,
                                 device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        data = converter([data], device=gpu_id)
        trainer.extend(PlotAttentionReport(model, data,
                                           args.outdir + "/att_ws"),
                       trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    def torch_save(path, _):
        # Saves both the state dict (at ``path``) and the whole module
        # (at ``path + ".pkl"``); the second argument is ignored.
        if ngpu > 1:
            torch.save(model.module.state_dict(), path)
            torch.save(model.module, path + ".pkl")
        else:
            torch.save(model.state_dict(), path)
            torch.save(model, path + ".pkl")

    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # save the model after each epoch
    trainer.extend(extensions.snapshot_object(model,
                                              'model_epoch{.updater.epoch}',
                                              savefun=torch_save),
                   trigger=(1, 'epoch'))

    # Compare strings by value; ``is not`` on a literal tests object identity.
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model,
                                       'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer
    def torch_load(path, obj):
        # Loads the state dict at ``path`` into the enclosing ``model`` and
        # returns ``obj`` unchanged.
        if ngpu > 1:
            model.module.load_state_dict(torch.load(path))
        else:
            model.load_state_dict(torch.load(path))
        return obj

    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics every 100 iterations
    trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(100, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(100, 'iteration'))

    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()
Beispiel #6
0
def main(args):
    """Train a DenseNet classifier on CIFAR-10 or CIFAR-100.

    Builds the datasets, iterators, model, optimizer, updater and trainer,
    registers the reporting/snapshot extensions and runs training for 300
    epochs.

    Args:
        args: Parsed argument namespace. The attributes read here are
            ``depth``, ``block``, ``dataset``, ``augment``, ``batchsize``,
            ``split_size``, ``growth_rate``, ``drop_ratio``, ``init_model``,
            ``gpu``, ``lr``, ``weight_decay`` and ``dir``.

    Raises:
        NotImplementedError: If ``args.dataset`` is ``'SVHN'``.
        ValueError: If ``args.dataset`` is not a recognised dataset name.
    """
    # (depth - block - 1) must be divisible by the number of blocks.
    assert((args.depth - args.block - 1) % args.block == 0)
    n_layer = int((args.depth - args.block - 1) / args.block)
    if args.dataset == 'cifar10':
        train, test = cifar.get_cifar10()
        n_class = 10
    elif args.dataset == 'cifar100':
        train, test = cifar.get_cifar100()
        n_class = 100
    elif args.dataset == 'SVHN':
        raise NotImplementedError()
    else:
        # Fail early with a clear message instead of hitting an unbound
        # local variable a few lines further down.
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    # Statistics of the training images, reduced over axes 0, 2 and 3
    # (presumably per-channel for NCHW data -- confirm against the loader).
    images = convert.concat_examples(train)[0]
    mean = images.mean(axis=(0, 2, 3))
    std = images.std(axis=(0, 2, 3))

    train = PreprocessedDataset(train, mean, std, random=args.augment)
    test = PreprocessedDataset(test, mean, std)

    # Integer division: the iterators need an integral batch size, and true
    # division would yield a float on Python 3.
    sub_batchsize = args.batchsize // args.split_size
    train_iter = chainer.iterators.SerialIterator(train, sub_batchsize)
    test_iter = chainer.iterators.SerialIterator(
        test, sub_batchsize, repeat=False, shuffle=False)

    model = chainer.links.Classifier(DenseNet(
        n_layer, args.growth_rate, n_class, args.drop_ratio, 16, args.block))
    if args.init_model:
        serializers.load_npz(args.init_model, model)
    chainer.cuda.get_device(args.gpu).use()
    model.to_gpu()

    optimizer = chainer.optimizers.NesterovAG(lr=args.lr, momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    # NOTE(review): this StandardUpdater accepts a (split_size, 'mean')
    # argument, so it is presumably a project-specific updater that splits
    # each batch into sub-batches -- confirm against its definition.
    updater = StandardUpdater(
        train_iter, optimizer, (args.split_size, 'mean'), device=args.gpu)
    trainer = training.Trainer(updater, (300, 'epoch'), out=args.dir)

    val_interval = (1, 'epoch')
    log_interval = (1, 'epoch')

    def lr_shift():  # DenseNet specific!
        # Scale the learning rate by 0.1 when the updater reports epoch 150
        # or 225; always return the current value so it can be logged.
        if updater.epoch == 150 or updater.epoch == 225:
            optimizer.lr *= 0.1
        return optimizer.lr

    trainer.extend(Evaluator(
        test_iter, model, device=args.gpu), trigger=val_interval)
    # Observing 'lr' once per epoch is also what drives the schedule above.
    trainer.extend(extensions.observe_value(
        'lr', lambda _: lr_shift()), trigger=(1, 'epoch'))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot_object(
        model, 'epoch_{.updater.epoch}.model'), trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        optimizer, 'epoch_{.updater.epoch}.state'), trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    # Wall-clock seconds since the extensions were registered.
    start_time = time.time()
    trainer.extend(extensions.observe_value(
        'time', lambda _: time.time() - start_time), trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'time', 'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'lr',
    ]), trigger=log_interval)
    # create_fig is called once per epoch with the output directory
    # (presumably to redraw the training curves -- confirm).
    trainer.extend(extensions.observe_value(
        'graph', lambda _: create_fig(args.dir)),
        trigger=(1, 'epoch'), priority=50)
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()
Beispiel #7
0
def main(args):
    """Train a DenseNet classifier on CIFAR-10 or CIFAR-100 on several GPUs.

    Builds the datasets, iterators, model, optimizer and a
    ``ParallelUpdater`` over ``args.gpus``, registers the reporting and
    snapshot extensions and runs training for ``args.epoch`` epochs.

    Args:
        args: Parsed argument namespace. The attributes read here are
            ``depth``, ``block``, ``dataset``, ``batchsize``,
            ``growth_rate``, ``drop_ratio``, ``init_model``, ``gpus``,
            ``lr``, ``weight_decay``, ``epoch`` and ``dir``.

    Raises:
        NotImplementedError: If ``args.dataset`` is ``'SVHN'``.
        ValueError: If ``args.dataset`` is not a recognised dataset name.
    """
    # (depth - block - 1) must be divisible by the number of blocks.
    assert((args.depth - args.block - 1) % args.block == 0)
    # Integer division keeps the layer count an int on Python 3 as well.
    n_layer = (args.depth - args.block - 1) // args.block

    if args.dataset == 'cifar10':
        # Normalisation constants taken from fb.resnet.torch.
        mean = numpy.asarray((125.3, 123.0, 113.9))
        # NOTE(review): unclear whether this std was computed on
        # zero-padded images.
        std = numpy.asarray((63.0, 62.1, 66.7))
        train, test = dataset.EXget_cifar10(scale=255, mean=mean, std=std)
        n_class = 10
    elif args.dataset == 'cifar100':
        # Normalisation constants taken from fb.resnet.torch.
        mean = numpy.asarray((129.3, 124.1, 112.4))
        std = numpy.asarray((68.2, 65.4, 70.4))
        train, test = dataset.EXget_cifar100(scale=255, mean=mean, std=std)
        n_class = 100
    elif args.dataset == 'SVHN':
        raise NotImplementedError()
    else:
        # Fail early with a clear message instead of hitting an unbound
        # local variable a few lines further down.
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    train = PreprocessedDataset(train, random=True)
    test = PreprocessedDataset(test)

    train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize)
    test_iter = chainer.iterators.MultiprocessIterator(
        test, args.batchsize, repeat=False, shuffle=False)

    model = chainer.links.Classifier(DenseNet(
        n_layer, args.growth_rate, n_class, args.drop_ratio, 16, args.block))
    if args.init_model:
        serializers.load_npz(args.init_model, model)

    import EXoptimizers
    # The learning rate is divided by the number of GPUs in use.
    optimizer = EXoptimizers.originalNesterovAG(
        lr=args.lr / len(args.gpus), momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    # The first GPU is the 'main' device; every other one gets a 'gpu<id>'
    # entry.
    devices = {'main': args.gpus[0]}
    if len(args.gpus) > 1:
        for gid in args.gpus[1:]:
            devices['gpu%d' % gid] = gid
    updater = training.ParallelUpdater(train_iter, optimizer, devices=devices)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.dir)

    val_interval = (1, 'epoch')
    log_interval = (1, 'epoch')

    def lr_shift():  # DenseNet specific!
        # Scale the learning rate by 0.1 when the updater reports epoch 151
        # or 226; always return the current value so it can be logged.
        if updater.epoch == 151 or updater.epoch == 226:
            optimizer.lr *= 0.1
        return optimizer.lr

    trainer.extend(Evaluator(
        test_iter, model, device=args.gpus[0]), trigger=val_interval)
    # Observing 'lr' once per epoch is also what drives the schedule above.
    trainer.extend(extensions.observe_value(
        'lr', lambda _: lr_shift()), trigger=(1, 'epoch'))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot_object(
        model, 'epoch_{.updater.epoch}.model'), trigger=val_interval)
    trainer.extend(extensions.snapshot_object(
        optimizer, 'epoch_{.updater.epoch}.state'), trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    # Wall-clock seconds since the extensions were registered.
    start_time = time.time()
    trainer.extend(extensions.observe_value(
        'time', lambda _: time.time() - start_time), trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'time', 'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'lr',
    ]), trigger=log_interval)
    # create_fig is called every second epoch with the output directory
    # (presumably to redraw the training curves -- confirm).
    trainer.extend(extensions.observe_value(
        'graph', lambda _: create_fig(args.dir)), trigger=(2, 'epoch'))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()
Beispiel #8
0
def train(args):
    """Run training.

    Sets up Chainer according to ``args``, reads the input/output
    dimensions from the validation json, builds the model, optimizer,
    iterators and updater (single device or multi GPU), registers the
    trainer extensions and runs the trainer.

    Args:
        args: Parsed argument namespace. Among others this function reads
            ``seed``, ``debugmode``, ``train_json``, ``valid_json``,
            ``atype``, ``mtlalpha``, ``outdir``, ``ngpu``, ``opt``, ``eps``,
            ``eps_decay``, ``grad_clip``, ``batch_size``, ``maxlen_in``,
            ``maxlen_out``, ``minibatches``, ``epochs``, ``resume``,
            ``num_save_attention`` and ``criterion``.

    Raises:
        NotImplementedError: If ``args.atype`` or ``args.opt`` is not
            supported.
    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    # seed setting (chainer seed may not need it)
    os.environ['CHAINER_SEED'] = str(args.seed)
    logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('chainer type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        chainer.config.cudnn_deterministic = False
        logging.info('chainer cudnn deterministic is disabled')
    else:
        chainer.config.cudnn_deterministic = True

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    # TODO(nelson) remove in future
    if 'input' not in valid_json[utts[0]]:
        logging.error("input file format (json) is modified, please redo "
                      "stage 2: Dictionary and Json Data Preparation")
        sys.exit(1)
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # check attention type
    if args.atype not in ['noatt', 'dot', 'location']:
        raise NotImplementedError(
            'chainer supports only noatt, dot, and location attention.')

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = Loss(e2e, args.mtlalpha)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.conf'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # TODO(watanabe) use others than pickle, possibly json, and save as a text
        pickle.dump((idim, odim, args), f)
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        # GPU 0 is the 'main' device, the others are named 'sub_<id>'
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    else:
        # Fail explicitly instead of leaving `optimizer` unbound.
        raise NotImplementedError('args.opt={}'.format(args.opt))
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # set up training iterator and updater
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(train_json, args.batch_size, args.maxlen_in,
                              args.maxlen_out, args.minibatches)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        train_iter = chainer.iterators.SerialIterator(train, 1)

        # set up updater
        updater = ChainerSeqUpdaterKaldi(train_iter,
                                         optimizer,
                                         converter=converter_kaldi,
                                         device=gpu_id)
    else:
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset: utterance i goes to GPU (i % ngpu).
            # `.items()` works on both Python 2 and 3 (`viewitems` is
            # Python 2 only).
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches)
            ]

        # each subset must have same length for MultiprocessParallelUpdater;
        # shorter subsets are padded by repeating their leading minibatches
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        train_iters = [
            chainer.iterators.MultiprocessIterator(train_subsets[gid],
                                                   1,
                                                   n_processes=1)
            for gid in six.moves.xrange(ngpu)
        ]

        # set up updater
        updater = ChainerMultiProcessParallelUpdaterKaldi(
            train_iters, optimizer, converter=converter_kaldi, devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  1,
                                                  repeat=False,
                                                  shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        extensions.Evaluator(valid_iter,
                             model,
                             converter=converter_kaldi,
                             device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        # take the first num_save_attention utterances and sort them by
        # the second entry of their input shape, largest first
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        data = converter_kaldi([data], device=gpu_id)
        trainer.extend(PlotAttentionReport(model, data,
                                           args.outdir + "/att_ws"),
                       trigger=(1, 'epoch'))

    # Take a snapshot for each specified epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # compare by value: `is not` on a string literal tests object identity
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer: when the monitored validation value
    # gets worse than the best one seen, restore the best model and decay eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics every 100 iterations
    trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # also report the current eps of the main optimizer
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(100, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(100, 'iteration'))

    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()
Beispiel #9
0
def train(args):
    """Train with the given args.

    The function proceeds in this order: read the input/output dimensions
    from the validation json, import and instantiate the model class named
    by ``args.model_module``, write ``model.json`` into ``args.outdir``,
    select the device(s), build the optimizer, converter, iterators and
    updater, register the trainer extensions, run the trainer and finally
    call ``check_early_stop``.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of ``'adadelta'``,
            ``'adam'`` or ``'noam'``, or if a batch-count mode other than
            ``'auto'``/``'seq'`` is requested with ``batch_size == 0`` on
            multiple GPUs.

    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    # (taken from the shape entries of the first utterance in the json)
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    logging.info('import model module: ' + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # the file is opened in binary mode, so the json text is encoded
        # to UTF-8 explicitly
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info('single gpu calculation.')
    elif ngpu > 1:
        gpu_id = 0
        # GPU 0 is the 'main' device, the others are named 'sub_<id>'
        devices = {'main': gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices['sub_%d' % gid] = gid
        logging.info('multi gpu calculation (#gpus = %d).' % ngpu)
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
    else:
        gpu_id = -1
        logging.info('cpu calculation')

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    elif args.opt == 'noam':
        # alpha starts at 0; a VaswaniRule extension on 'alpha' is
        # registered further down for this optimizer (presumably it
        # schedules alpha every iteration -- confirm against VaswaniRule)
        optimizer = chainer.optimizers.Adam(alpha=0,
                                            beta1=0.9,
                                            beta2=0.98,
                                            eps=1e-9)
    else:
        raise NotImplementedError('args.opt={}'.format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup Training Extensions
    # (the converter/updater classes come from the transformer or the rnn
    # backend depending on the model module name)
    if 'transformer' in args.model_module:
        from espnet.nets.chainer_backend.transformer.training import CustomConverter
        from espnet.nets.chainer_backend.transformer.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.transformer.training import CustomUpdater
    else:
        from espnet.nets.chainer_backend.rnn.training import CustomConverter
        from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater
        from espnet.nets.chainer_backend.rnn.training import CustomUpdater

    # Setup a converter
    converter = CustomConverter(subsampling_factor=model.subsample[0])

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    # sortagrad is active when args.sortagrad is -1 or positive; while it
    # is active the training iterators are created with shuffle=False
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(train_json,
                              args.batch_size,
                              args.maxlen_in,
                              args.maxlen_out,
                              args.minibatches,
                              min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                              shortest_first=use_sortagrad,
                              count=args.batch_count,
                              batch_bins=args.batch_bins,
                              batch_frames_in=args.batch_frames_in,
                              batch_frames_out=args.batch_frames_out,
                              batch_frames_inout=args.batch_frames_inout,
                              iaxis=0,
                              oaxis=0)
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(TransformDataset(
                    train, load_tr),
                                                  batch_size=1,
                                                  shuffle=not use_sortagrad)
            ]

        # set up updater
        updater = CustomUpdater(train_iters[0],
                                optimizer,
                                converter=converter,
                                device=gpu_id,
                                accum_grad=accum_grad)
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented in chainer multi gpu"
            )
        # set up minibatches
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # make subset: utterance i goes to GPU (i % ngpu)
            train_json_subset = {
                k: v
                for i, (k, v) in enumerate(train_json.items())
                if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(train_json_subset, args.batch_size,
                              args.maxlen_in, args.maxlen_out,
                              args.minibatches)
            ]

        # each subset must have same length for MultiprocessParallelUpdater;
        # shorter subsets are padded by repeating their leading minibatches
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(TransformDataset(
                    train_subsets[gid], load_tr),
                                                  batch_size=1,
                                                  shuffle=not use_sortagrad)
                for gid in six.moves.xrange(ngpu)
            ]

        # set up updater
        updater = CustomParallelUpdater(train_iters,
                                        optimizer,
                                        converter=converter,
                                        devices=devices)

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        # ShufflingEnabler is triggered at epoch args.sortagrad, or at
        # args.epochs when sortagrad is -1 (presumably it switches the
        # iterators' shuffling back on -- confirm against ShufflingEnabler)
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))
    if args.opt == 'noam':
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule
        trainer.extend(VaswaniRule('alpha',
                                   d=args.adim,
                                   warmup_steps=args.transformer_warmup_steps,
                                   scale=args.transformer_lr),
                       trigger=(1, 'iteration'))
    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        valid_iter = chainer.iterators.SerialIterator(TransformDataset(
            valid, load_cv),
                                                      batch_size=1,
                                                      repeat=False,
                                                      shuffle=False)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        # take the first num_save_attention utterances and sort them by
        # the second entry of their input shape, largest first
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # use the wrapped model when the model exposes a `module` attribute
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info('Using custom PlotAttentionReport')
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=gpu_id)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        # still passed to TensorboardLogger below, so it must be bound
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename='snapshot.ep.{.updater.epoch}'),
        trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer: when the monitored validation value
    # gets worse than the best one seen, restore the best model and decay eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best'),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics every report_interval_iters
    # iterations
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # also report the current eps of the main optimizer
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    # TensorBoard logging is enabled only when a non-empty directory is given
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter),
                       trigger=(args.report_interval_iters, 'iteration'))

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #10
0
def main():
    """Fine-tune and/or evaluate a BERT sentence classifier with Chainer.

    All configuration is read from the module-level ``FLAGS`` object. Up to
    three independent phases run, in this order:

    * ``FLAGS.do_train``: fine-tune the classifier on the task's training
      examples; a model snapshot is written at the final iteration into
      ``FLAGS.output_dir``.
    * ``FLAGS.do_eval``: run a Chainer ``Evaluator`` over the dev examples
      and print the resulting dict.
    * ``FLAGS.do_print_test``: print the pooled BERT output (and its shape)
      plus the converted batch for at most three dev examples.

    Raises:
        ValueError: If none of the three phase flags is set, if
            ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``, or if ``FLAGS.task_name`` is not a
            known task.
    """
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_print_test:
        raise ValueError("At least one of `do_train` or `do_eval` "
                         "or `do_print_test` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(FLAGS.output_dir, exist_ok=True)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None

    # The schedule lengths are derived from the size of the training set.
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    bert = modeling.BertModel(config=bert_config)
    model = modeling.BertClassifier(bert, num_labels=len(label_list))
    # 'output/W' and 'output/b' are excluded from loading (presumably the
    # task-specific classification head -- confirm against BertClassifier).
    chainer.serializers.load_npz(
        FLAGS.init_checkpoint, model,
        ignore_names=['output/W', 'output/b'])

    if FLAGS.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
        model.to_gpu()

    if FLAGS.do_train:
        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6, weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        train_iter = chainer.iterators.SerialIterator(
            train_examples, FLAGS.train_batch_size)
        converter = Converter(
            label_list, FLAGS.max_seq_length, tokenizer)
        updater = training.updaters.StandardUpdater(
            train_iter, optimizer,
            converter=converter,
            device=FLAGS.gpu)
        trainer = training.Trainer(
            updater, (num_train_steps, 'iteration'), out=FLAGS.output_dir)

        # learning rate (eta) scheduling in Adam
        # The decay starts at the value a straight line from
        # (0, learning_rate) to (num_train_steps, 0) takes at the end of
        # the warmup phase.
        lr_decay_init = FLAGS.learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.extend(extensions.LinearShift(  # decay
            'eta', (lr_decay_init, 0.), (num_warmup_steps, num_train_steps)))
        trainer.extend(extensions.WarmupShift(  # warmup
            'eta', 0., num_warmup_steps, FLAGS.learning_rate))
        trainer.extend(extensions.observe_value(
            'eta', lambda trainer: trainer.updater.get_optimizer('main').eta),
            trigger=(50, 'iteration'))  # logging

        # The trigger equals the stop iteration, so exactly one snapshot is
        # taken, at the end of training.
        trainer.extend(extensions.snapshot_object(
            model, 'model_snapshot_iter_{.updater.iteration}.npz'),
            trigger=(num_train_steps, 'iteration'))
        trainer.extend(extensions.LogReport(
            trigger=(50, 'iteration')))
        trainer.extend(extensions.PrintReport(
            ['iteration', 'main/loss',
             'main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.run()

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        test_iter = chainer.iterators.SerialIterator(
            eval_examples, FLAGS.train_batch_size * 2,
            repeat=False, shuffle=False)
        converter = Converter(
            label_list, FLAGS.max_seq_length, tokenizer)
        evaluator = extensions.Evaluator(
            test_iter, model, converter=converter, device=FLAGS.gpu)
        results = evaluator()
        print(results)

    # if you wanna see some output arrays for debugging
    if FLAGS.do_print_test:
        short_eval_examples = processor.get_dev_examples(FLAGS.data_dir)[:3]
        short_eval_examples = short_eval_examples[:FLAGS.eval_batch_size]
        short_test_iter = chainer.iterators.SerialIterator(
            short_eval_examples, FLAGS.eval_batch_size,
            repeat=False, shuffle=False)
        converter = Converter(
            label_list, FLAGS.max_seq_length, tokenizer)
        # No Evaluator is built here: the batch is pushed through the model
        # by hand. (The previous code constructed an unused Evaluator from
        # `test_iter`, which is only bound when do_eval is set and therefore
        # crashed when do_print_test was used on its own.)

        with chainer.using_config('train', False):
            with chainer.no_backprop_mode():
                data = next(short_test_iter)
                # The converter's last element is dropped before the call
                # (presumably the labels -- confirm against Converter).
                out = model.bert.get_pooled_output(
                    *converter(data, FLAGS.gpu)[:-1])
                print(out)
                print(out.shape)
            print(converter(data, -1))
Beispiel #11
0
def main():
    """Pre-train a BERT model with Chainer, configured by module-level ``FLAGS``.

    * ``FLAGS.do_train``: trains a ``BertPretrainer`` with a
      ``ParallelUpdater`` over eight devices (ids 0-7) for
      ``FLAGS.num_train_steps`` iterations, logging every iteration,
      plotting every 100 iterations and snapshotting the model every 1000
      iterations into ``FLAGS.output_dir``.
    * ``FLAGS.do_eval``: runs the evaluation block at the end of this
      function (see the review note there).

    Raises:
        ValueError: If neither ``FLAGS.do_train`` nor ``FLAGS.do_eval`` is
            set.
    """
    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(FLAGS.output_dir, exist_ok=True)

    def _get_text_file(text_dir):
        """Return the comma-joined list of pickled input files.

        NOTE: ``text_dir`` is currently ignored; the dataset location is
        hard-coded below.
        """
        # Alternatives that were used previously:
        #   glob of f'{text_dir}/**/*'
        #   sequence length 512:
        #     '/nfs/ai16storage01/sec/akp2/1706nasubi/inatomi/benchmark/bert-chainer/data/wiki_data_pickle/all'
        #   debug:
        #     '/nfs/ai16storage01/sec/akp2/1706nasubi/inatomi/benchmark/bert-chainer/data/wiki_data_pickle/AA/wiki_00'
        # sequence length 128
        file_list = ['/nfs/ai16storage01/sec/akp2/1706nasubi/inatomi/benchmark/bert-chainer/data/wiki_data_pickle/all_seq128']
        return ",".join(file_list)
    input_files = _get_text_file(FLAGS.input_file).split(',')

    bert = modeling.BertModel(config=bert_config)
    model = modeling.BertPretrainer(bert)
    if FLAGS.init_checkpoint:
        serializers.load_npz(FLAGS.init_checkpoint, model)
        # Re-wrap the loaded BERT body in a fresh pretrainer.
        model = modeling.BertPretrainer(model.bert)
    # The single-GPU transfer is intentionally disabled here; it used to be:
    #   chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
    #   model.to_gpu()

    if FLAGS.do_train:
        # Original author's note (translated): describe pre-training in
        # Chainer. Would it be enough to create a replacement for
        # BERTClassification and massage BERT's outputs so that it returns
        # the same thing model_fn returns?

        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6, weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        # Original author's note (translated): ConcatenatedDataset is
        # in-memory, so it cannot handle pickles of huge datasets. The
        # disabled approach loaded a subset of `input_files` concurrently
        # with a ThreadPoolExecutor via _load_data_using_dataset_api and
        # combined the results with ConcatenatedDataset. Only the first
        # input file is loaded instead.
        train_examples = _load_data_using_dataset_api(input_files[0])

        train_iter = chainer.iterators.SerialIterator(
            train_examples, FLAGS.train_batch_size)
        converter = Converter()
        updater = training.updaters.ParallelUpdater(
            iterator=train_iter,
            optimizer=optimizer,
            converter=converter,
            # The device of the name 'main' is used as a "master", while others are
            # used as slaves. Names other than 'main' are arbitrary.
            devices={'main': 0,
                     '1': 1,
                     '2': 2,
                     '3': 3,
                     '4': 4,
                     '5': 5,
                     '6': 6,
                     '7': 7,
                     },
        )
        # learning rate (eta) scheduling in Adam
        num_warmup_steps = FLAGS.num_warmup_steps
        num_train_steps = FLAGS.num_train_steps
        trainer = training.Trainer(
            updater, (num_train_steps, 'iteration'), out=FLAGS.output_dir)
        # The decay starts at the value a straight line from
        # (0, learning_rate) to (num_train_steps, 0) takes at the end of
        # the warmup phase.
        lr_decay_init = FLAGS.learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.extend(extensions.LinearShift(  # decay
            'eta', (lr_decay_init, 0.), (num_warmup_steps, num_train_steps)))
        trainer.extend(extensions.WarmupShift(  # warmup
            'eta', 0., num_warmup_steps, FLAGS.learning_rate))
        trainer.extend(extensions.observe_value(
            'eta', lambda trainer: trainer.updater.get_optimizer('main').eta),
            trigger=(50, 'iteration'))  # logging

        trainer.extend(extensions.snapshot_object(
            model, 'seq_128_model_snapshot_iter_{.updater.iteration}.npz'),
            trigger=(1000, 'iteration'))
        trainer.extend(extensions.LogReport(
            trigger=(1, 'iteration')))
        # All losses and accuracies share a single plot.
        trainer.extend(extensions.PlotReport(
            y_keys=[
                'main/loss',
                'main/next_sentence_loss',
                'main/next_sentence_accuracy',
                'main/masked_lm_loss',
                'main/masked_lm_accuracy',
             ], x_key='iteration', trigger=(100, 'iteration'), file_name='loss.png'))
        trainer.extend(extensions.PrintReport(
            ['iteration',
             'main/loss',
             'main/masked_lm_loss', 'main/masked_lm_accuracy',
             'main/next_sentence_loss', 'main/next_sentence_accuracy',
             'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=20))

        trainer.run()

    if FLAGS.do_eval:
        # NOTE(review): `estimator` and `input_fn_builder` are not bound
        # anywhere in this function; this block looks like a leftover from
        # the TensorFlow implementation and will fail unless both are
        # provided at module level -- confirm before relying on do_eval.
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        result = estimator.evaluate(
            input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
def main():
    """Fine-tune and/or evaluate a BERT sentence classifier with Chainer.

    All configuration is read from the module-level ``FLAGS`` object. The
    BERT body is initialised from a ``BertPretrainer`` checkpoint
    (``FLAGS.init_checkpoint``) and wrapped in a fresh ``BertClassifier``.
    Up to three independent phases then run, in this order:

    * ``FLAGS.do_train``: fine-tune the classifier on the task's training
      examples; a model snapshot is written at the final iteration into
      ``FLAGS.output_dir``.
    * ``FLAGS.do_eval``: run a Chainer ``Evaluator`` over the dev examples
      and print the resulting dict.
    * ``FLAGS.do_print_test``: print the pooled BERT output (and its shape)
      plus the converted batch for at most three dev examples.

    Raises:
        ValueError: If none of the three phase flags is set, if
            ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``, or if ``FLAGS.task_name`` is not a
            known task.
    """
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "livedoor": LivedoorProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_print_test:
        raise ValueError("At least one of `do_train` or `do_eval` "
                         "or `do_print_test` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(FLAGS.output_dir, exist_ok=True)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(model_file=FLAGS.model_file,
                                           vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None

    # The schedule lengths are derived from the size of the training set.
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Load the checkpoint into a pretrainer, then reuse only its BERT body
    # inside a newly created classifier.
    bert = modeling.BertModel(config=bert_config)
    pretrained = modeling.BertPretrainer(bert)
    chainer.serializers.load_npz(FLAGS.init_checkpoint, pretrained)

    model = modeling.BertClassifier(pretrained.bert,
                                    num_labels=len(label_list))

    if FLAGS.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
        model.to_gpu()

    if FLAGS.do_train:
        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6,
            weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        train_iter = chainer.iterators.SerialIterator(train_examples,
                                                      FLAGS.train_batch_size)
        converter = Converter(label_list, FLAGS.max_seq_length, tokenizer)
        updater = training.updaters.StandardUpdater(train_iter,
                                                    optimizer,
                                                    converter=converter,
                                                    device=FLAGS.gpu)
        trainer = training.Trainer(updater, (num_train_steps, 'iteration'),
                                   out=FLAGS.output_dir)

        # learning rate (eta) scheduling in Adam
        # The decay starts at the value a straight line from
        # (0, learning_rate) to (num_train_steps, 0) takes at the end of
        # the warmup phase.
        lr_decay_init = FLAGS.learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.extend(
            extensions.LinearShift(  # decay
                'eta', (lr_decay_init, 0.),
                (num_warmup_steps, num_train_steps)))
        trainer.extend(
            extensions.WarmupShift(  # warmup
                'eta', 0., num_warmup_steps, FLAGS.learning_rate))
        trainer.extend(extensions.observe_value(
            'eta', lambda trainer: trainer.updater.get_optimizer('main').eta),
                       trigger=(50, 'iteration'))  # logging

        # The trigger equals the stop iteration, so exactly one snapshot is
        # taken, at the end of training.
        trainer.extend(extensions.snapshot_object(
            model, 'model_snapshot_iter_{.updater.iteration}.npz'),
                       trigger=(num_train_steps, 'iteration'))
        trainer.extend(extensions.LogReport(trigger=(50, 'iteration')))
        trainer.extend(
            extensions.PrintReport(
                ['iteration', 'main/loss', 'main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.run()

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        test_iter = chainer.iterators.SerialIterator(eval_examples,
                                                     FLAGS.train_batch_size *
                                                     2,
                                                     repeat=False,
                                                     shuffle=False)
        converter = Converter(label_list, FLAGS.max_seq_length, tokenizer)
        evaluator = extensions.Evaluator(test_iter,
                                         model,
                                         converter=converter,
                                         device=FLAGS.gpu)
        results = evaluator()
        print(results)

    # if you wanna see some output arrays for debugging
    if FLAGS.do_print_test:
        short_eval_examples = processor.get_dev_examples(FLAGS.data_dir)[:3]
        short_eval_examples = short_eval_examples[:FLAGS.eval_batch_size]
        short_test_iter = chainer.iterators.SerialIterator(
            short_eval_examples,
            FLAGS.eval_batch_size,
            repeat=False,
            shuffle=False)
        converter = Converter(label_list, FLAGS.max_seq_length, tokenizer)
        # No Evaluator is built here: the batch is pushed through the model
        # by hand. (The previous code constructed an unused Evaluator from
        # `test_iter`, which is only bound when do_eval is set and therefore
        # crashed when do_print_test was used on its own.)

        with chainer.using_config('train', False):
            with chainer.no_backprop_mode():
                data = next(short_test_iter)
                # The converter's last element is dropped before the call
                # (presumably the labels -- confirm against Converter).
                out = model.bert.get_pooled_output(
                    *converter(data, FLAGS.gpu)[:-1])
                print(out)
                print(out.shape)
            print(converter(data, -1))
def observe_pr(iterator_name='main', observation_key='pr'):
    """Build a trainer extension that reports an iterator's ``p_batch_ratio``.

    Args:
        iterator_name: Key of the iterator to observe in the updater's
            ``_iterators`` mapping. Defaults to ``'main'``.
        observation_key: Name under which the value is reported.

    Returns:
        The extension created by ``extensions.observe_value``.
    """
    # Look up the requested iterator; previously 'main' was hard-coded and
    # the ``iterator_name`` argument was silently ignored.
    return extensions.observe_value(
        observation_key,
        lambda trainer: trainer.updater._iterators[iterator_name].p_batch_ratio)
Beispiel #14
0
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1])
        for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for i in range(args.num_encs):
        logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
    logging.info("#output dims: " + str(odim))

    # specify semi-supervised method
    assert 0.0 <= args.mixup_alpha <= 1.0, "mixup-alpha should be [0.0, 1.0]"
    if args.mixup_alpha == 0.0:
        semi_mode = "MT"
        logging.info("Pure Mean-Teacher mode")
    else:
        semi_mode = "ICT"
        logging.info("Interpolation Consistency Training mode")

    if (args.enc_init is not None
            or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim_list[0] if args.num_encs == 1 else idim_list,
                            odim, args)

    assert isinstance(model, ASRInterface)

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim,
                 vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup.")

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.enc.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.enc.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(model.enc, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    elif args.opt == "rmsprop":
        optimizer = torch.optim.RMSprop(model.enc.parameters(),
                                        lr=0.0008,
                                        alpha=0.95)
    elif args.opt == "sgd":
        optimizer = torch.optim.SGD(model.enc.parameters(),
                                    lr=0.5,
                                    momentum=0.9,
                                    nesterov=True)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0],
                                    dtype=dtype)
    else:
        converter = CustomConverterMulEnc([i[0] for i in model.subsample_list],
                                          dtype=dtype)

    # read json data
    assert 0.0 < args.utt_using_ratio < 1.0, "utt-using-ratio should not be 0 or 1"

    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    if args.train_json is not None:
        with open(args.train_json, 'rb') as f:
            train_json = json.load(f)['utts']

        train_labeled_json_path = args.train_json.replace(
            '.json',
            '_labeled_{}.json'.format(int(args.utt_using_ratio * 100)))
        train_unlabeled_json_path = args.train_json.replace(
            '.json',
            '_unlabeled_{}.json'.format(100 - int(args.utt_using_ratio * 100)))
        if os.path.exists(train_labeled_json_path):
            with open(train_labeled_json_path, 'rb') as f:
                train_labeled_json = json.load(f)['utts']
            with open(train_unlabeled_json_path, 'rb') as f:
                train_unlabeled_json = json.load(f)['utts']
        else:
            # split json for each task
            split_point = [int(len(train_json) * args.utt_using_ratio)]
            train_labeled_json = dict(
                list(train_json.items())[:split_point[0]])
            train_unlabeled_json = dict(
                list(train_json.items())[split_point[0]:])
            with codecs.open(train_labeled_json_path, 'w+',
                             encoding='utf8') as f:
                json.dump({'utts': train_labeled_json},
                          f,
                          indent=4,
                          sort_keys=True,
                          ensure_ascii=False,
                          separators=(',', ': '))
            with codecs.open(train_unlabeled_json_path, 'w+',
                             encoding='utf8') as f:
                json.dump({'utts': train_unlabeled_json},
                          f,
                          indent=4,
                          sort_keys=True,
                          ensure_ascii=False,
                          separators=(',', ': '))
    else:
        with open(args.label_train_json, 'rb') as f:
            train_labeled_json = json.load(f)['utts']
        with open(args.unlabel_train_json, 'rb') as f:
            train_unlabeled_json = json.load(f)['utts']

    valid_labeled_json_path = args.valid_json.replace(
        '.json', '_labeled_{}.json'.format(int(args.utt_using_ratio * 100)))
    valid_unlabeled_json_path = args.valid_json.replace(
        '.json',
        '_unlabeled_{}.json'.format(100 - int(args.utt_using_ratio * 100)))
    if os.path.exists(valid_labeled_json_path):
        with open(valid_labeled_json_path, 'rb') as f:
            valid_labeled_json = json.load(f)['utts']
        with open(valid_unlabeled_json_path, 'rb') as f:
            valid_unlabeled_json = json.load(f)['utts']
    else:
        # split json for each task
        split_point = [int(len(valid_json) * args.utt_using_ratio)]
        valid_labeled_json = dict(list(valid_json.items())[:split_point[0]])
        valid_unlabeled_json = dict(list(valid_json.items())[split_point[0]:])
        with codecs.open(valid_labeled_json_path, 'w+', encoding='utf8') as f:
            json.dump({'utts': valid_labeled_json},
                      f,
                      indent=4,
                      sort_keys=True,
                      ensure_ascii=False,
                      separators=(',', ': '))
        with codecs.open(valid_unlabeled_json_path, 'w+',
                         encoding='utf8') as f:
            json.dump({'utts': valid_unlabeled_json},
                      f,
                      indent=4,
                      sort_keys=True,
                      ensure_ascii=False,
                      separators=(',', ': '))

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_labeled_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)
    valid = make_batchset(valid_labeled_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)
    ul_train = make_batchset(train_unlabeled_json,
                             args.batch_size,
                             args.maxlen_in,
                             args.maxlen_out,
                             args.minibatches,
                             min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                             count=args.batch_count,
                             batch_bins=args.batch_bins,
                             batch_frames_in=args.batch_frames_in,
                             batch_frames_out=args.batch_frames_out,
                             batch_frames_inout=args.batch_frames_inout,
                             iaxis=0,
                             oaxis=0)
    ul_valid = make_batchset(valid_unlabeled_json,
                             args.batch_size,
                             args.maxlen_in,
                             args.maxlen_out,
                             args.minibatches,
                             min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                             count=args.batch_count,
                             batch_bins=args.batch_bins,
                             batch_frames_in=args.batch_frames_in,
                             batch_frames_out=args.batch_frames_out,
                             batch_frames_inout=args.batch_frames_inout,
                             iaxis=0,
                             oaxis=0)
    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    load_ul_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_ul_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )
    # Add splited data iteration for training
    ul_train_iter = ChainerDataLoader(
        dataset=TransformDataset(ul_train,
                                 lambda data: converter([load_ul_tr(data)])),
        batch_size=1,
        shuffle=True,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )
    ul_valid_iter = ChainerDataLoader(
        dataset=TransformDataset(ul_valid,
                                 lambda data: converter([load_ul_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up ICT related arguments
    ICT_args = {
        "consistency_rampup_starts": args.consistency_rampup_starts,
        "consistency_rampup_ends": args.consistency_rampup_ends,
        "cosine_rampdown_starts": args.cosine_rampdown_starts,
        "cosine_rampdown_ends": args.cosine_rampdown_ends,
        "ema_pre_decay": args.ema_pre_decay,
        "ema_post_decay": args.ema_post_decay
    }

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {
            "main": train_iter,
            "sub": ul_train_iter
        },
        optimizer,
        device,
        args.ngpu,
        ICT_args,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    # TODO: custom evaluator
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {
                "main": valid_iter,
                "sub": ul_valid_iter
            }, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {
                "main": valid_iter,
                "sub": ul_valid_iter
            }, reporter, device, args.ngpu))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ce",
                "validation/main/loss_ce",
                "main/loss_mse",
                "validation/main/loss_mse",
            ],
            "epoch",
            file_name="loss.png",
        ))
    trainer.extend(
        extensions.PlotReport(
            ["main/teacher_acc", "validation/main/teacher_acc"] +
            (["main/student_acc", "validation/main/student_acc"]
             if args.show_student_model_acc else []),
            "epoch",
            file_name="acc.png"))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger(
            "validation/main/teacher_acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/student_acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/student_acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )

    # lr decay in rmsprop
    elif args.opt == "rmsprop" or "sgd":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/teacher_acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                rmsprop_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/teacher_acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                rmsprop_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ce",
        "main/loss_mse",
        "validation/main/loss",
        "validation/main/loss_ce",
        "validation/main/loss_mse",
        "main/teacher_acc",
        "validation/main/teacher_acc",
        "elapsed_time",
    ] + ["main/student_acc", "validation/main/student_acc"
         ] if args.show_student_model_acc else []
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    if args.opt == "rmsprop" or "sgd":
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["lr"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), None),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #15
0
def main():
    """Train an SSD detector on PASCAL VOC with adaptive loss scaling.

    The function parses the command line, builds an SSD300/SSD512 model,
    wraps it with ``AdaLossScaled`` and trains it through a ChainerMN
    multi-node optimizer.  Evaluation, logging, progress display and
    snapshots are registered on rank 0 only.  After training, rank 0 exports
    the recorded loss scales, profiling data, sanity-check results (if
    enabled) and gradient statistics as CSV files under ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model',
                        choices=('ssd300', 'ssd512'),
                        default='ssd300')
    parser.add_argument('--batchsize', type=int, default=32)
    parser.add_argument('--np',
                        type=int,
                        default=8,
                        help='Number of worker processes of the training '
                        'data iterator')
    parser.add_argument('--test-batchsize', type=int, default=16)
    parser.add_argument('--iteration', type=int, default=120000)
    parser.add_argument('--step', type=int, nargs='*', default=[80000, 100000])
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--out', default='result')
    parser.add_argument('--resume')
    parser.add_argument('--dtype',
                        type=str,
                        choices=dtypes.keys(),
                        default='float32',
                        help='Select the data type of the model')
    parser.add_argument('--model-dir',
                        default=None,
                        type=str,
                        help='Where to store models')
    parser.add_argument('--dataset-dir',
                        default=None,
                        type=str,
                        help='Where to store datasets')
    parser.add_argument('--dynamic-interval',
                        default=None,
                        type=int,
                        help='Interval for dynamic loss scaling')
    parser.add_argument('--init-scale',
                        default=1,
                        type=float,
                        help='Initial scale for ada loss')
    parser.add_argument('--loss-scale-method',
                        default='approx_range',
                        type=str,
                        help='Method for adaptive loss scaling')
    parser.add_argument('--scale-upper-bound',
                        default=16,
                        type=float,
                        help='Hard upper bound for each scale factor')
    parser.add_argument('--accum-upper-bound',
                        default=1024,
                        type=float,
                        help='Accumulated upper bound for all scale factors')
    parser.add_argument('--update-per-n-iteration',
                        default=1,
                        type=int,
                        help='Update the loss scale value per n iteration')
    parser.add_argument('--snapshot-per-n-iteration',
                        default=10000,
                        type=int,
                        help='The frequency of taking snapshots')
    parser.add_argument('--n-uf', default=1e-3, type=float)
    parser.add_argument('--nosanity-check', default=False, action='store_true')
    parser.add_argument('--nouse-fp32-update',
                        default=False,
                        action='store_true')
    parser.add_argument('--profiling', default=False, action='store_true')
    parser.add_argument('--verbose',
                        action='store_true',
                        default=False,
                        help='Verbose output')
    args = parser.parse_args()

    # Start a dummy process with the 'forkserver' start method before the
    # communicator is created, as recommended for MultiprocessIterator in
    # https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#using-multiprocessiterator
    if hasattr(multiprocessing, 'set_start_method'):
        multiprocessing.set_start_method('forkserver')
        p = multiprocessing.Process()
        p.start()
        p.join()

    comm = chainermn.create_communicator('pure_nccl')
    device = comm.intra_rank

    # Set up workspace
    # 16 GiB GPU RAM for workspace
    chainer.cuda.set_max_workspace_size(16 * 1024 * 1024 * 1024)
    chainer.global_config.cv_resize_backend = 'cv2'

    # Setup the data type
    # when initializing models as follows, their data types will be casted.
    # We have to forbid the usage of cuDNN for non-float32 data types.
    if args.dtype != 'float32':
        chainer.global_config.use_cudnn = 'never'
    chainer.global_config.dtype = dtypes[args.dtype]
    print('==> Setting the data type to {}'.format(args.dtype))

    # The dataset root is pointed at --model-dir while the model is built
    # (presumably so pretrained weights are cached there -- TODO confirm);
    # it is switched to --dataset-dir further below.
    if args.model_dir is not None:
        chainer.dataset.set_dataset_root(args.model_dir)
    if args.model == 'ssd300':
        model = SSD300(n_fg_class=len(voc_bbox_label_names),
                       pretrained_model='imagenet')
    elif args.model == 'ssd512':
        model = SSD512(n_fg_class=len(voc_bbox_label_names),
                       pretrained_model='imagenet')

    model.use_preset('evaluate')

    ######################################
    # Setup model
    #######################################
    # Apply ada loss transform
    recorder = AdaLossRecorder(sample_per_n_iter=100)
    profiler = Profiler()
    sanity_checker = SanityChecker(
        check_per_n_iter=100) if not args.nosanity_check else None
    # Update the model to support AdaLoss
    # TODO: refactorize
    model_ = AdaLossScaled(
        model,
        init_scale=args.init_scale,
        cfg={
            'loss_scale_method': args.loss_scale_method,
            'scale_upper_bound': args.scale_upper_bound,
            'accum_upper_bound': args.accum_upper_bound,
            'update_per_n_iteration': args.update_per_n_iteration,
            'recorder': recorder,
            'profiler': profiler,
            'sanity_checker': sanity_checker,
            'n_uf_threshold': args.n_uf,
            # 'power_of_two': False,
        },
        transforms=[
            AdaLossTransformLinear(),
            AdaLossTransformConvolution2D(),
        ],
        verbose=args.verbose)

    if comm.rank == 0:
        print(model)

    train_chain = MultiboxTrainChain(model_, comm=comm)
    chainer.cuda.get_device_from_id(device).use()

    # to GPU
    model.coder.to_gpu()
    model.extractor.to_gpu()
    model.multibox.to_gpu()

    # Size passed as ``shared_mem`` to the training iterator below.
    shared_mem = 100 * 1000 * 1000 * 4

    if args.dataset_dir is not None:
        chainer.dataset.set_dataset_root(args.dataset_dir)
    train = TransformDataset(
        ConcatenatedDataset(VOCBboxDataset(year='2007', split='trainval'),
                            VOCBboxDataset(year='2012', split='trainval')),
        ('img', 'mb_loc', 'mb_label'),
        Transform(model.coder,
                  model.insize,
                  model.mean,
                  dtype=dtypes[args.dtype]))

    # Only rank 0 builds the index array; scatter_dataset distributes a
    # shuffled share of it to every process.
    if comm.rank == 0:
        indices = np.arange(len(train))
    else:
        indices = None
    indices = chainermn.scatter_dataset(indices, comm, shuffle=True)
    train = train.slice[indices]

    # The global batch size is divided among the processes.  The worker
    # count comes from --np (default 8) instead of a hard-coded constant.
    train_iter = chainer.iterators.MultiprocessIterator(train,
                                                        args.batchsize //
                                                        comm.size,
                                                        n_processes=args.np,
                                                        n_prefetch=2,
                                                        shared_mem=shared_mem)

    if comm.rank == 0:  # NOTE: only performed on the first device
        test = VOCBboxDataset(year='2007',
                              split='test',
                              use_difficult=True,
                              return_difficult=True)
        test_iter = chainer.iterators.SerialIterator(test,
                                                     args.test_batchsize,
                                                     repeat=False,
                                                     shuffle=False)

    # initial lr is set to 1e-3 by ExponentialShift
    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.MomentumSGD(), comm)
    if args.dtype == 'mixed16':
        if not args.nouse_fp32_update:
            print('==> Using FP32 update for dtype=mixed16')
            optimizer.use_fp32_update()  # by default use fp32 update

        # HACK: support skipping update by existing loss scaling functionality
        if args.dynamic_interval is not None:
            optimizer.loss_scaling(interval=args.dynamic_interval, scale=None)
        else:
            optimizer.loss_scaling(interval=float('inf'), scale=None)
            optimizer._loss_scale_max = 1.0  # to prevent actual loss scaling

    optimizer.setup(train_chain)
    # Biases get their gradient scaled by 2; all other parameters get
    # weight decay.
    for param in train_chain.params():
        if param.name == 'b':
            param.update_rule.add_hook(GradientScaling(2))
        else:
            param.update_rule.add_hook(WeightDecay(0.0005))

    updater = training.updaters.StandardUpdater(train_iter,
                                                optimizer,
                                                device=device)
    # if args.dtype == 'mixed16':
    #     updater.loss_scale = 8
    iteration_interval = (args.iteration, 'iteration')

    trainer = training.Trainer(updater, iteration_interval, args.out)
    # trainer.extend(extensions.ExponentialShift('lr', 0.1, init=args.lr),
    #                trigger=triggers.ManualScheduleTrigger(
    #                    args.step, 'iteration'))
    # NOTE(review): the LR schedule below is registered only when the batch
    # size differs from 32.  With the default batch size nothing in this
    # function applies --lr / --step to the optimizer -- confirm that this
    # is intended.
    if args.batchsize != 32:
        warmup_attr_ratio = 0.1
        # NOTE: this is confusing but it means n_iter
        warmup_n_epoch = 1000
        lr_shift = chainerlp.extensions.ExponentialShift(
            'lr',
            0.1,
            init=args.lr * warmup_attr_ratio,
            warmup_attr_ratio=warmup_attr_ratio,
            warmup_n_epoch=warmup_n_epoch,
            schedule=args.step)
        trainer.extend(lr_shift, trigger=(1, 'iteration'))

    if comm.rank == 0:
        # Evaluation is skipped entirely in profiling mode; otherwise it
        # runs at every LR step and at the final iteration.
        if not args.profiling:
            trainer.extend(DetectionVOCEvaluator(
                test_iter,
                model,
                use_07_metric=True,
                label_names=voc_bbox_label_names),
                           trigger=triggers.ManualScheduleTrigger(
                               args.step + [args.iteration], 'iteration'))

        log_interval = 10, 'iteration'
        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_lr(), trigger=log_interval)
        trainer.extend(extensions.observe_value(
            'loss_scale',
            lambda trainer: trainer.updater.get_optimizer('main')._loss_scale),
                       trigger=log_interval)

        metrics = [
            'epoch', 'iteration', 'lr', 'main/loss', 'main/loss/loc',
            'main/loss/conf', 'validation/main/map'
        ]
        # The loss scale is only printed when dynamic loss scaling is on.
        if args.dynamic_interval is not None:
            metrics.insert(2, 'loss_scale')

        trainer.extend(extensions.PrintReport(metrics), trigger=log_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.extend(extensions.snapshot(),
                       trigger=(args.snapshot_per_n_iteration, 'iteration'))
        trainer.extend(extensions.snapshot_object(
            model, 'model_iter_{.updater.iteration}'),
                       trigger=(args.iteration, 'iteration'))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    hook = AdaLossMonitor(sample_per_n_iter=100,
                          verbose=args.verbose,
                          includes=['Grad', 'Deconvolution'])
    recorder.trainer = trainer
    hook.trainer = trainer

    # The monitoring hook is only entered on rank 0; ExitStack lets the
    # other ranks run the trainer without it.
    with ExitStack() as stack:
        if comm.rank == 0:
            stack.enter_context(hook)
        trainer.run()

    # store recorded results
    if comm.rank == 0:  # NOTE: only export in the first rank
        recorder.export().to_csv(os.path.join(args.out, 'loss_scale.csv'))
        profiler.export().to_csv(os.path.join(args.out, 'profile.csv'))
        if sanity_checker:
            sanity_checker.export().to_csv(
                os.path.join(args.out, 'sanity_check.csv'))
        hook.export_history().to_csv(os.path.join(args.out, 'grad_stats.csv'))
def main():
    """Train a PointNet auto-encoder on a single point-cloud class.

    Parses the command line, builds ``ae.PointNetAE``, trains it with Adam
    for ``--epoch`` epochs (optionally evaluating on the validation split)
    and finally saves the model parameters to ``<out>/<model_filename>``.
    When ``--resume`` is given, the trainer state is restored from that
    snapshot before training starts.
    """
    parser = argparse.ArgumentParser(description='AutoEncoder ShapeNet')
    # parser.add_argument('--conv-layers', '-c', type=int, default=4)
    parser.add_argument('--batchsize', '-b', type=int, default=32)
    parser.add_argument('--dropout_ratio', type=float, default=0)
    parser.add_argument('--num_point', '-n', type=int, default=1024)
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--out', '-o', type=str, default='result')
    parser.add_argument('--epoch', '-e', type=int, default=250)
    parser.add_argument('--model_filename',
                        '-m',
                        type=str,
                        default='model.npz')
    parser.add_argument('--resume', '-r', type=str, default='')
    parser.add_argument('--trans', '-t', type=strtobool, default='true')
    parser.add_argument('--use_bn', type=strtobool, default='true')
    parser.add_argument('--residual', type=strtobool, default='false')
    parser.add_argument('--use_val', '-v', type=strtobool, default='true')
    parser.add_argument('--class_choice', '-c', type=str, default='Chair')
    args = parser.parse_args()

    batch_size = args.batchsize
    dropout_ratio = args.dropout_ratio
    num_point = args.num_point
    device = args.gpu
    out_dir = args.out
    epoch = args.epoch
    model_filename = args.model_filename
    resume = args.resume
    trans = args.trans
    use_bn = args.use_bn
    residual = args.residual
    use_val = args.use_val
    class_choice = args.class_choice

    # Fixed hyper-parameters passed to the model constructor below.
    trans_lam1 = 0.001
    trans_lam2 = 0.001
    out_dim = 3
    in_dim = 3
    middle_dim = 64

    os.makedirs(out_dir, exist_ok=True)
    # Saving the arguments is best effort: only the optional import and the
    # save itself are guarded, and only a missing package is tolerated.
    try:
        import chainerex.utils as cl
        fp = os.path.join(out_dir, 'args.json')
        cl.save_json(fp, vars(args))
        print('save args to', fp)
    except ImportError:
        pass

    # Network
    print('Train PointNet-AutoEncoder model... trans={} use_bn={} dropout={}'.
          format(trans, use_bn, dropout_ratio))
    model = ae.PointNetAE(out_dim=out_dim,
                          in_dim=in_dim,
                          middle_dim=middle_dim,
                          dropout_ratio=dropout_ratio,
                          use_bn=use_bn,
                          trans=trans,
                          trans_lam1=trans_lam1,
                          trans_lam2=trans_lam2,
                          residual=residual,
                          output_points=num_point)

    print("Dataset setting... num_point={} use_val={}".format(
        num_point, use_val))
    # Dataset preparation

    train = dataset.ChainerPointCloudDatasetDefault(
        split="train", class_choice=[class_choice], num_point=num_point)
    if use_val:
        val = dataset.ChainerPointCloudDatasetDefault(
            split="val", class_choice=[class_choice], num_point=num_point)
        val_iter = iterators.SerialIterator(ConcatenatedDataset(*([val])),
                                            batch_size,
                                            repeat=False,
                                            shuffle=False)
    train_iter = iterators.SerialIterator(ConcatenatedDataset(*([train])),
                                          batch_size)

    print("GPU setting...")
    # gpu setting
    if (device >= 0):
        print('using gpu {}'.format(device))
        chainer.backends.cuda.get_device_from_id(device).use()
        model.to_gpu()

    # Optimizer
    optimizer = optimizers.Adam()
    optimizer.setup(model)

    # training
    converter = concat_examples
    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       device=device,
                                       converter=converter)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out=out_dir)

    from chainerex.training.extensions import schedule_optimizer_value
    from chainer.training.extensions import observe_value
    # trainer.extend(observe_lr)
    # Adam exposes its step size as ``alpha``; it is reported under 'lr'.
    observation_key = 'lr'
    trainer.extend(
        observe_value(
            observation_key,
            lambda trainer: trainer.updater.get_optimizer('main').alpha))
    # NOTE(review): presumably (epochs, values) pairs for the optimizer
    # step size -- confirm against chainerex's schedule_optimizer_value.
    trainer.extend(
        schedule_optimizer_value(
            [10, 20, 100, 150, 200, 230],
            [0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]))

    if use_val:
        trainer.extend(
            E.Evaluator(val_iter, model, converter=converter, device=device))
        trainer.extend(
            E.PrintReport([
                'epoch', 'main/loss', 'main/dist_loss', 'main/trans_loss1',
                'main/trans_loss2', 'validation/main/loss',
                'validation/main/dist_loss', 'validation/main/trans_loss1',
                'validation/main/trans_loss2', 'lr', 'elapsed_time'
            ]))
    else:
        trainer.extend(
            E.PrintReport([
                'epoch', 'main/loss', 'main/dist_loss', 'main/trans_loss1',
                'main/trans_loss2', 'lr', 'elapsed_time'
            ]))
    # A single trainer snapshot is taken at the final epoch.
    trainer.extend(E.snapshot(), trigger=(epoch, 'epoch'))
    trainer.extend(E.LogReport())
    trainer.extend(E.ProgressBar(update_interval=10))

    # Restore the trainer state when --resume was given.  (The value used to
    # be overwritten with '' right here, which silently disabled the option.)
    if resume:
        serializers.load_npz(resume, trainer)
    print("Traning start.")
    trainer.run()

    # --- save classifier ---
    # protocol = args.protocol
    # classifier.save_pickle(
    #     os.path.join(out_dir, args.model_filename), protocol=protocol)
    serializers.save_npz(os.path.join(out_dir, model_filename), model)
def main():
    """Train an image classifier with ChainerMN and adaptive loss scaling.

    Relies on the module-level ``args`` namespace and ``_mean`` value (both
    defined outside this function).  Builds the datasets from
    ``args.dataset_dir``, wraps the selected architecture with
    ``AdaLossScaled``, trains it with a multi-node corrected momentum SGD
    using a warm-up plus step LR schedule, and on rank 0 exports the
    recorded loss scales and gradient statistics as CSV files under
    ``args.out``.
    """
    # Start the multiprocessing environment
    # https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#using-multiprocessiterator
    if hasattr(multiprocessing, 'set_start_method'):
        multiprocessing.set_start_method('forkserver')
        p = multiprocessing.Process()
        p.start()
        p.join()

    # Set up workspace
    # 16 GiB GPU RAM for workspace
    chainer.cuda.set_max_workspace_size(16 * 1024 * 1024 * 1024)

    # Setup the multi-node environment
    comm = chainermn.create_communicator(args.communicator)
    device = comm.intra_rank
    print(
        '==> Successfully setup communicator: "{}" rank: {} device: {} size: {}'
        .format(args.communicator, comm.rank, device, comm.size))
    set_random_seed(args, device)

    # Setup LR
    if args.lr is not None:
        lr = args.lr
    else:
        # Scale 0.1 linearly with the global batch size relative to 256.
        lr = 0.1 * (args.batchsize * comm.size) / 256  # TODO: why?
        if comm.rank == 0:
            print(
                'LR = {} is selected based on the linear scaling rule'.format(
                    lr))

    # Setup dataset
    train_dir = os.path.join(args.dataset_dir, 'train')
    val_dir = os.path.join(args.dataset_dir, 'val')
    label_names = datasets.directory_parsing_label_names(train_dir)
    train_data = datasets.DirectoryParsingLabelDataset(train_dir)
    val_data = datasets.DirectoryParsingLabelDataset(val_dir)
    train_data = TransformDataset(train_data, ('img', 'label'),
                                  TrainTransform(_mean, args))
    val_data = TransformDataset(val_data, ('img', 'label'),
                                ValTransform(_mean, args))
    print('==> [{}] Successfully finished loading dataset'.format(comm.rank))

    # Initializing dataset iterators
    # Only rank 0 builds the index arrays; scatter_dataset distributes a
    # share of them to every process.
    if comm.rank == 0:
        train_indices = np.arange(len(train_data))
        val_indices = np.arange(len(val_data))
    else:
        train_indices = None
        val_indices = None

    train_indices = chainermn.scatter_dataset(train_indices,
                                              comm,
                                              shuffle=True)
    val_indices = chainermn.scatter_dataset(val_indices, comm, shuffle=True)
    train_data = train_data.slice[train_indices]
    val_data = val_data.slice[val_indices]
    train_iter = chainer.iterators.MultiprocessIterator(
        train_data, args.batchsize, n_processes=args.loaderjob)
    val_iter = iterators.MultiprocessIterator(val_data,
                                              args.batchsize,
                                              repeat=False,
                                              shuffle=False,
                                              n_processes=args.loaderjob)

    # Create the model
    kwargs = {}
    if args.first_bn_mixed16 and args.dtype == 'float16':
        print('==> Setting the first BN layer to mixed16')
        kwargs['first_bn_mixed16'] = True

    # Initialize the model
    net = models.__dict__[args.arch](n_class=len(label_names), **kwargs)
    # Following https://arxiv.org/pdf/1706.02677.pdf,
    # the gamma of the last BN of each resblock is initialized by zeros.
    for link in net.links():
        if isinstance(link, Bottleneck):
            link.conv3.bn.gamma.data[:] = 0

    # Apply ada loss transform
    recorder = AdaLossRecorder(sample_per_n_iter=100)
    # Update the model to support AdaLoss
    net = AdaLossScaled(net,
                        init_scale=args.init_scale,
                        cfg={
                            'loss_scale_method': args.loss_scale_method,
                            'scale_upper_bound': args.scale_upper_bound,
                            'accum_upper_bound': args.accum_upper_bound,
                            'update_per_n_iteration':
                            args.update_per_n_iteration,
                            'recorder': recorder,
                        },
                        transforms=[
                            AdaLossTransformLinear(),
                            AdaLossTransformBottleneck(),
                            AdaLossTransformBasicBlock(),
                            AdaLossTransformConv2DBNActiv(),
                        ],
                        verbose=args.verbose)

    if comm.rank == 0:  # print the network only on the rank-0 process
        print(net)
    net = L.Classifier(net)
    hook = AdaLossMonitor(sample_per_n_iter=100,
                          verbose=args.verbose,
                          includes=['Grad', 'Deconvolution'])

    # Setup optimizer
    optim = chainermn.create_multi_node_optimizer(
        optimizers.CorrectedMomentumSGD(lr=lr, momentum=args.momentum), comm)
    if args.dtype == 'mixed16':
        print('==> Using FP32 update for dtype=mixed16')
        optim.use_fp32_update()  # by default use fp32 update

        # HACK: support skipping update by existing loss scaling functionality
        if args.dynamic_interval is not None:
            optim.loss_scaling(interval=args.dynamic_interval, scale=None)
        else:
            optim.loss_scaling(interval=float('inf'), scale=None)
            optim._loss_scale_max = 1.0  # to prevent actual loss scaling

    optim.setup(net)

    # setup weight decay (skipped for parameters named 'beta' / 'gamma')
    for param in net.params():
        if param.name not in ('beta', 'gamma'):
            param.update_rule.add_hook(WeightDecay(args.weight_decay))

    # allocate model to multiple GPUs
    if device >= 0:
        # get_device() is deprecated; get_device_from_id() is the
        # replacement for integer device ids.
        chainer.cuda.get_device_from_id(device).use()
        net.to_gpu()

    # Create an updater that implements how to update based on one train_iter input
    updater = chainer.training.StandardUpdater(train_iter,
                                               optim,
                                               device=device)
    # Setup Trainer
    # An iteration budget, when given, takes precedence over the epoch count.
    stop_trigger = (args.epoch, 'epoch')
    if args.iter is not None:
        stop_trigger = (args.iter, 'iteration')
    trainer = training.Trainer(updater, stop_trigger, out=args.out)

    @make_shift('lr')
    def warmup_and_exponential_shift(trainer):
        """Return the learning rate for the current (fractional) epoch.

        For the first 5 epochs, a base ``lr`` above 0.1 is ramped linearly
        from 0.1 up to ``lr``; otherwise ``lr`` is used as is.  Afterwards
        the rate is multiplied by 0.1 at epochs 30, 60 and 80.

        NOTE: ``lr`` is taken from the enclosing scope.
        """
        epoch = trainer.updater.epoch_detail
        warmup_epoch = 5  # NOTE: mentioned the original ResNet paper.
        if epoch < warmup_epoch:
            if lr > 0.1:
                warmup_rate = 0.1 / lr
                rate = warmup_rate \
                    + (1 - warmup_rate) * epoch / warmup_epoch
            else:
                rate = 1
        elif epoch < 30:
            rate = 1
        elif epoch < 60:
            rate = 0.1
        elif epoch < 80:
            rate = 0.01
        else:
            rate = 0.001
        return rate * lr

    trainer.extend(warmup_and_exponential_shift)
    evaluator = chainermn.create_multi_node_evaluator(
        extensions.Evaluator(val_iter, net, device=device), comm)
    trainer.extend(evaluator, trigger=(1, 'epoch'))

    log_interval = 0.1, 'epoch'
    print_interval = 0.1, 'epoch'

    # Reporting and snapshot extensions are registered on rank 0 only.
    if comm.rank == 0:
        print('==========================================')
        print('Num process (COMM_WORLD): {}'.format(comm.size))
        print('Using {} communicator'.format(args.communicator))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')

        trainer.extend(chainer.training.extensions.observe_lr(),
                       trigger=log_interval)

        # NOTE: may take snapshot every iteration now
        # The snapshot unit follows the stop trigger's unit.
        snapshot_label = 'epoch' if args.iter is None else 'iteration'
        snapshot_trigger = (args.snapshot_freq, snapshot_label)
        snapshot_filename = ('snapshot_' + snapshot_label + '_{.updater.' +
                             snapshot_label + '}.npz')
        trainer.extend(extensions.snapshot(filename=snapshot_filename),
                       trigger=snapshot_trigger)

        trainer.extend(extensions.LogReport(trigger=log_interval))
        trainer.extend(extensions.observe_value(
            'loss_scale',
            lambda trainer: trainer.updater.get_optimizer('main')._loss_scale),
                       trigger=log_interval)
        trainer.extend(extensions.PrintReport([
            'iteration', 'epoch', 'elapsed_time', 'lr', 'loss_scale',
            'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy'
        ]),
                       trigger=print_interval)
        trainer.extend(extensions.ProgressBar(update_interval=10))

    if args.resume:
        serializers.load_npz(args.resume, trainer)

    recorder.trainer = trainer
    hook.trainer = trainer
    # The monitoring hook is only entered on rank 0; ExitStack lets the
    # other ranks run the trainer without it.
    with ExitStack() as stack:
        if comm.rank == 0:
            stack.enter_context(hook)
        trainer.run()

    # store recorded results
    if comm.rank == 0:  # NOTE: only export in the first rank
        recorder.export().to_csv(os.path.join(args.out, 'loss_scale.csv'))
        hook.export_history().to_csv(os.path.join(args.out, 'grad_stats.csv'))
Beispiel #18
0
def train(args):
    """Run chainer training of the E2E model configured by ``args``.

    Reads the label JSON files and the Kaldi feature scp files named in
    ``args``, builds the ``E2E``/``Loss`` model, pickles the model
    configuration to ``<args.outdir>/model.conf`` and runs a chainer
    ``Trainer`` for ``args.epochs`` epochs with evaluation, snapshot,
    plotting and reporting extensions attached.

    Args:
        args: Parsed options object. The attributes read here are ``seed``,
            ``debugmode``, ``valid_label``, ``train_label``, ``atype``,
            ``mtlalpha``, ``outdir``, ``gpu``, ``opt``, ``eps``,
            ``eps_decay``, ``grad_clip``, ``batch_size``, ``maxlen_in``,
            ``maxlen_out``, ``minibatches``, ``train_feat``, ``valid_feat``,
            ``epochs``, ``resume`` and ``criterion``.

    Raises:
        NotImplementedError: If ``args.atype`` is not one of 'noatt', 'dot'
            or 'location'.
        ValueError: If ``args.opt`` is neither 'adadelta' nor 'adam'.
    """
    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    # seed setting (chainer seed may not need it)
    os.environ['CHAINER_SEED'] = str(args.seed)
    logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # considering reproducibility
    # the type check is only kept for debugmode >= 2
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('chainer type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        chainer.config.cudnn_deterministic = False
        logging.info('chainer cudnn deterministic is disabled')
    else:
        chainer.config.cudnn_deterministic = True

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_label, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['idim'])
    odim = int(valid_json[utts[0]]['odim'])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # check attention type
    if args.atype not in ['noatt', 'dot', 'location']:
        raise NotImplementedError(
            'chainer supports only noatt, dot, and location attention.')

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = Loss(e2e, args.mtlalpha)

    # write model config
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = os.path.join(args.outdir, 'model.conf')
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # TODO(watanabe) use others than pickle, possibly json, and save as a text
        pickle.dump((idim, odim, args), f)
    for key, value in sorted(vars(args).items()):
        logging.info('ARGS: ' + key + ': ' + str(value))

    # Set gpu
    gpu_id = int(args.gpu)
    logging.info('gpu id: ' + str(gpu_id))
    if gpu_id >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == 'adam':
        optimizer = chainer.optimizers.Adam()
    else:
        # Fail with a clear message instead of hitting an unbound
        # ``optimizer`` in the setup call below.
        raise ValueError('unsupported optimizer: ' + str(args.opt))
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # read json data; the validation labels were already parsed above for
    # the dimension lookup, so only the training labels are read here
    with open(args.train_label, 'rb') as f:
        train_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    valid = make_batchset(valid_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    # hack: the iterator batch size is 1 because each dataset element is
    # already a whole minibatch (a list)
    train_iter = chainer.iterators.SerialIterator(train, 1)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  1,
                                                  repeat=False,
                                                  shuffle=False)

    # prepare Kaldi reader
    train_reader = lazy_io.read_dict_scp(args.train_feat)
    valid_reader = lazy_io.read_dict_scp(args.valid_feat)

    # Set up a trainer
    updater = ChainerSeqUpdaterKaldi(train_iter, optimizer, train_reader,
                                     gpu_id)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    # NOTE(review): this runs before any extension is registered, so
    # extension state stored in the snapshot may not be restored -- confirm
    # whether that is intended before moving it.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Evaluate the model with the validation dataset (default trigger)
    trainer.extend(
        ChainerSeqEvaluaterKaldi(valid_iter,
                                 model,
                                 valid_reader,
                                 device=gpu_id))

    # Take a snapshot every epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(
        extensions.snapshot_object(model, 'model.acc.best'),
        trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer: when the monitored validation value
    # gets worse than the best one, restore the best model and decay eps
    if args.opt == 'adadelta':
        if args.criterion == 'acc':
            monitor_key = 'validation/main/acc'
            best_model = os.path.join(args.outdir, 'model.acc.best')

            def got_worse(best_value, current_value):
                return best_value > current_value
        elif args.criterion == 'loss':
            monitor_key = 'validation/main/loss'
            best_model = os.path.join(args.outdir, 'model.loss.best')

            def got_worse(best_value, current_value):
                return best_value < current_value
        else:
            # any other criterion registers no eps decay
            monitor_key = None

        if monitor_key is not None:
            # each extension gets its own trigger instance
            trainer.extend(restore_snapshot(model, best_model),
                           trigger=CompareValueTrigger(monitor_key,
                                                       got_worse))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(monitor_key,
                                                       got_worse))

    # Write a log of evaluation statistics every 100 iterations
    trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # also report the current AdaDelta eps
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').eps),
                       trigger=(100, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(100, 'iteration'))

    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()
Beispiel #19
0
def main():
    """Fine-tune and/or evaluate the BERT SQuAD model as selected by FLAGS.

    With ``FLAGS.do_train`` the training file is converted to features and a
    chainer ``Trainer`` is run for the computed number of iterations, using a
    warmup followed by a linear decay of the optimizer's ``eta``. With
    ``FLAGS.do_predict`` the prediction file is converted to features and
    passed to ``evaluate``.

    Raises:
        ValueError: If none of ``do_train``/``do_predict``/``do_print_test``
            is set, if a selected mode lacks its input file, or if
            ``FLAGS.max_seq_length`` exceeds the model's
            ``max_position_embeddings``.
    """
    any_mode_selected = (FLAGS.do_train or FLAGS.do_predict
                         or FLAGS.do_print_test)
    if not any_mode_selected:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    # each selected mode needs its input file
    if FLAGS.do_train and not FLAGS.train_file:
        raise ValueError(
            "If `do_train` is True, then `train_file` must be specified.")
    if FLAGS.do_predict and not FLAGS.predict_file:
        raise ValueError(
            "If `do_predict` is True, then `predict_file` must be specified."
        )

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    max_positions = bert_config.max_position_embeddings
    if FLAGS.max_seq_length > max_positions:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, max_positions))

    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file,
        do_lower_case=FLAGS.do_lower_case)

    train_examples = num_train_steps = num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = read_squad_examples(input_file=FLAGS.train_file,
                                             is_training=True)
        train_features = convert_examples_to_features(
            train_examples,
            tokenizer,
            FLAGS.max_seq_length,
            FLAGS.doc_stride,
            FLAGS.max_query_length,
            is_training=True)
        # the step count is derived from the number of examples
        batches_per_epoch = len(train_examples) / FLAGS.train_batch_size
        num_train_steps = int(batches_per_epoch * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model = modeling.BertSQuAD(modeling.BertModel(config=bert_config))
    # When training, the 'output/W' and 'output/b' entries are skipped while
    # loading the checkpoint; for prediction only, everything is loaded.
    ignore_names = ['output/W', 'output/b'] if FLAGS.do_train else None
    chainer.serializers.load_npz(FLAGS.init_checkpoint,
                                 model,
                                 ignore_names=ignore_names)

    if FLAGS.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
        model.to_gpu()

    if FLAGS.do_train:
        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6,
            weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        train_iter = chainer.iterators.SerialIterator(
            train_features, FLAGS.train_batch_size)
        train_converter = Converter(is_training=True)
        updater = training.updaters.StandardUpdater(
            train_iter,
            optimizer,
            converter=train_converter,
            device=FLAGS.gpu,
            loss_func=model.compute_loss)
        stop_trigger = (num_train_steps, 'iteration')
        trainer = training.Trainer(updater, stop_trigger,
                                   out=FLAGS.output_dir)

        # learning rate (eta) scheduling in Adam
        steps_after_warmup = num_train_steps - num_warmup_steps
        lr_decay_init = (FLAGS.learning_rate * steps_after_warmup /
                         num_train_steps)
        decay = extensions.LinearShift(
            'eta', (lr_decay_init, 0.),
            (num_warmup_steps, num_train_steps))
        trainer.extend(decay)
        warmup = extensions.WarmupShift(
            'eta', 0., num_warmup_steps, FLAGS.learning_rate)
        trainer.extend(warmup)

        def current_eta(trainer):
            # value reported under the 'eta' key
            return trainer.updater.get_optimizer('main').eta

        report_interval = (100, 'iteration')
        trainer.extend(extensions.observe_value('eta', current_eta),
                       trigger=report_interval)  # logging

        model_snapshot = extensions.snapshot_object(
            model, 'model_snapshot_iter_{.updater.iteration}.npz')
        trainer.extend(model_snapshot,
                       trigger=(num_train_steps // 2, 'iteration'))  # TODO
        trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
        printed_keys = [
            'iteration', 'main/loss', 'main/accuracy', 'elapsed_time', 'eta'
        ]
        trainer.extend(extensions.PrintReport(printed_keys))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.run()

    if FLAGS.do_predict:
        eval_examples = read_squad_examples(input_file=FLAGS.predict_file,
                                            is_training=False)
        eval_features = convert_examples_to_features(
            eval_examples,
            tokenizer,
            FLAGS.max_seq_length,
            FLAGS.doc_stride,
            FLAGS.max_query_length,
            is_training=False)
        # single ordered pass over the evaluation features
        eval_iter = chainer.iterators.SerialIterator(
            eval_features,
            FLAGS.predict_batch_size,
            repeat=False,
            shuffle=False)
        eval_converter = Converter(is_training=False)

        print('Evaluating ...')
        evaluate(eval_examples,
                 eval_iter,
                 model,
                 converter=eval_converter,
                 device=FLAGS.gpu,
                 predict_func=model.predict)
        print('Finished.')
Beispiel #20
0
def main():
    """Train the point-cloud classifier chosen by ``--method`` and save it.

    Parses the command line, builds the train/validation datasets (wrapped
    in a KD-tree transform for the ``kdnet_cls``/``kdcontextnet_cls``
    methods), constructs the selected model, trains it with Adam under a
    chainer ``Trainer`` and finally writes the model to
    ``<out>/<model_filename>`` with ``serializers.save_npz``.

    Raises:
        ValueError: If ``--method`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser(description='ModelNet40 classification')
    # parser.add_argument('--conv-layers', '-c', type=int, default=4)
    parser.add_argument('--method', '-m', type=str, default='point_cls')
    parser.add_argument('--batchsize', '-b', type=int, default=32)
    parser.add_argument('--dropout_ratio', type=float, default=0.3)
    parser.add_argument('--num_point', type=int, default=1024)
    parser.add_argument('--gpu', '-g', type=int, default=-1)
    parser.add_argument('--out', '-o', type=str, default='result')
    parser.add_argument('--epoch', '-e', type=int, default=250)
    # parser.add_argument('--unit-num', '-u', type=int, default=16)
    parser.add_argument('--seed', '-s', type=int, default=777)
    parser.add_argument('--protocol', type=int, default=2)
    parser.add_argument('--model_filename', type=str, default='model.npz')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--trans', type=strtobool, default='true')
    parser.add_argument('--use_bn', type=strtobool, default='true')
    parser.add_argument('--normalize', type=strtobool, default='false')
    parser.add_argument('--residual', type=strtobool, default='false')
    args = parser.parse_args()

    # NOTE: --seed and --protocol are accepted for CLI compatibility but are
    # not used anywhere in this function.
    method = args.method
    num_point = args.num_point
    out_dir = args.out
    device = args.gpu
    num_class = 40
    debug = False  # flip to True to train on a small subset

    os.makedirs(out_dir, exist_ok=True)
    # Dumping the arguments is best-effort: it is skipped when chainerex
    # cannot be imported.
    try:
        import chainerex.utils as cl
    except ImportError:
        pass
    else:
        fp = os.path.join(out_dir, 'args.json')
        cl.save_json(fp, vars(args))
        print('save args to', fp)

    # Dataset preparation
    train = get_train_dataset(num_point=num_point)
    val = get_test_dataset(num_point=num_point)
    if method in ('kdnet_cls', 'kdcontextnet_cls'):
        from chainer_pointnet.utils.kdtree import TransformKDTreeCls
        max_level = calc_max_level(num_point)
        print('kdnet_cls max_level {}'.format(max_level))
        # only kdnet_cls asks the transform for the split dimensions
        return_split_dims = (method == 'kdnet_cls')
        train = TransformDataset(
            train,
            TransformKDTreeCls(max_level=max_level,
                               return_split_dims=return_split_dims))
        val = TransformDataset(
            val,
            TransformKDTreeCls(max_level=max_level,
                               return_split_dims=return_split_dims))
        if return_split_dims:
            # Debug print
            points, split_dims, t = train[0]
            print('converted to kdnet dataset train', points.shape,
                  split_dims.shape, t)
            points, split_dims, t = val[0]
            print('converted to kdnet dataset val', points.shape,
                  split_dims.shape, t)
        else:
            # Debug print
            points, t = train[0]
            print('converted to kdcontextnet dataset train', points.shape, t)
            points, t = val[0]
            print('converted to kdcontextnet dataset val', points.shape, t)

    if debug:
        # use few train dataset
        train = SubDataset(train, 0, 50)

    # Network
    # n_unit = args.unit_num
    # conv_layers = args.conv_layers
    trans = args.trans
    use_bn = args.use_bn
    normalize = args.normalize
    residual = args.residual
    dropout_ratio = args.dropout_ratio
    from chainer.dataset.convert import concat_examples
    converter = concat_examples

    if method == 'point_cls':
        print(
            'Train PointNetCls model... trans={} use_bn={} dropout={}'.format(
                trans, use_bn, dropout_ratio))
        model = PointNetCls(out_dim=num_class,
                            in_dim=3,
                            middle_dim=64,
                            dropout_ratio=dropout_ratio,
                            trans=trans,
                            trans_lam1=0.001,
                            trans_lam2=0.001,
                            use_bn=use_bn,
                            residual=residual)
    elif method == 'point2_cls_ssg':
        print('Train PointNet2ClsSSG model... use_bn={} dropout={}'.format(
            use_bn, dropout_ratio))
        model = PointNet2ClsSSG(out_dim=num_class,
                                in_dim=3,
                                dropout_ratio=dropout_ratio,
                                use_bn=use_bn,
                                residual=residual)
    elif method == 'point2_cls_msg':
        print('Train PointNet2ClsMSG model... use_bn={} dropout={}'.format(
            use_bn, dropout_ratio))
        model = PointNet2ClsMSG(out_dim=num_class,
                                in_dim=3,
                                dropout_ratio=dropout_ratio,
                                use_bn=use_bn,
                                residual=residual)
    elif method == 'kdnet_cls':
        print('Train KDNetCls model... use_bn={} dropout={}'.format(
            use_bn, dropout_ratio))
        model = KDNetCls(
            out_dim=num_class,
            in_dim=3,
            dropout_ratio=dropout_ratio,
            use_bn=use_bn,
            max_level=max_level,
        )

        def kdnet_converter(batch, device=None, padding=None):
            """Concatenate on CPU, then move only non-object arrays."""
            # concat_examples to CPU at first.
            result = concat_examples(batch, device=None, padding=padding)
            # dtype=object arrays are left where they are; int/float arrays
            # are sent to ``device``.
            return tuple(
                elem if elem.dtype == object else to_device(device, elem)
                for elem in result)

        converter = kdnet_converter
    elif method == 'kdcontextnet_cls':
        # The two literals are concatenated, hence the explicit space
        # before 'normalize'.
        print('Train KDContextNetCls model... use_bn={} dropout={} '
              'normalize={} residual={}'.format(use_bn, dropout_ratio,
                                                normalize, residual))
        model = KDContextNetCls(
            out_dim=num_class,
            in_dim=3,
            dropout_ratio=dropout_ratio,
            use_bn=use_bn,
            # Below is for non-default customization
            levels=[3, 6, 9],
            feature_learning_mlp_list=[[32, 32, 128], [64, 64, 256],
                                       [128, 128, 512]],
            feature_aggregation_mlp_list=[[128], [256], [512]],
            normalize=normalize,
            residual=residual)
    else:
        raise ValueError('[ERROR] Invalid method {}'.format(method))

    train_iter = iterators.SerialIterator(train, args.batchsize)
    val_iter = iterators.SerialIterator(val,
                                        args.batchsize,
                                        repeat=False,
                                        shuffle=False)

    # classifier = Classifier(model, device=device)
    classifier = model
    load_model = False  # flip to True to start from a previously saved model
    if load_model:
        serializers.load_npz(os.path.join(out_dir, args.model_filename),
                             classifier)
    if device >= 0:
        print('using gpu {}'.format(device))
        chainer.cuda.get_device_from_id(device).use()
        classifier.to_gpu()  # Copy the model to the GPU

    optimizer = optimizers.Adam()
    optimizer.setup(classifier)

    updater = training.StandardUpdater(train_iter,
                                       optimizer,
                                       converter=converter,
                                       device=device)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=out_dir)

    from chainerex.training.extensions import schedule_optimizer_value
    from chainer.training.extensions import observe_value
    # trainer.extend(observe_lr)
    # Adam's ``alpha`` is reported under the 'lr' key.
    observation_key = 'lr'
    trainer.extend(
        observe_value(
            observation_key,
            lambda trainer: trainer.updater.get_optimizer('main').alpha))
    trainer.extend(
        schedule_optimizer_value(
            [10, 20, 100, 150, 200, 230],
            [0.003, 0.001, 0.0003, 0.0001, 0.00003, 0.00001]))

    trainer.extend(
        E.Evaluator(val_iter, classifier, converter=converter,
                    device=device))
    # a single trainer snapshot, taken at the final epoch
    trainer.extend(E.snapshot(), trigger=(args.epoch, 'epoch'))
    trainer.extend(E.LogReport())
    trainer.extend(
        E.PrintReport([
            'epoch',
            'main/loss',
            'main/cls_loss',
            'main/trans_loss1',
            'main/trans_loss2',
            'main/accuracy',
            'validation/main/loss',
            # 'validation/main/cls_loss',
            # 'validation/main/trans_loss1', 'validation/main/trans_loss2',
            'validation/main/accuracy',
            'lr',
            'elapsed_time'
        ]))
    trainer.extend(E.ProgressBar(update_interval=10))

    if args.resume:
        serializers.load_npz(args.resume, trainer)
    trainer.run()

    # --- save classifier ---
    # protocol = args.protocol
    # classifier.save_pickle(
    #     os.path.join(out_dir, args.model_filename), protocol=protocol)
    serializers.save_npz(os.path.join(out_dir, args.model_filename),
                         classifier)
Beispiel #21
0
def main():
  """Train an SSD model on VOC with adaptive loss scaling and export stats.

  Parses the command line, wraps the selected SSD model in ``AdaLossScaled``,
  trains it on the VOC 2007 + 2012 ``trainval`` splits with MomentumSGD,
  evaluates on the VOC 2007 ``test`` split at the scheduled iterations, and
  finally writes the recorder, profiler, sanity-checker (when enabled) and
  gradient-monitor results as CSV files under ``--out``.
  """
  cli = argparse.ArgumentParser()
  cli.add_argument('--model', default='ssd300', choices=('ssd300', 'ssd512'))
  cli.add_argument('--batchsize', type=int, default=32)
  cli.add_argument('--test-batchsize', type=int, default=16)
  cli.add_argument('--iteration', type=int, default=120000)
  cli.add_argument('--step', type=int, nargs='*', default=[80000, 100000])
  cli.add_argument('--gpu', type=int, default=-1)
  cli.add_argument('--out', default='result')
  cli.add_argument('--resume')
  cli.add_argument('--dtype', type=str, choices=dtypes.keys(),
                   default='float32',
                   help='Select the data type of the model')
  cli.add_argument('--model-dir', type=str, default=None,
                   help='Where to store models')
  cli.add_argument('--dataset-dir', type=str, default=None,
                   help='Where to store datasets')
  cli.add_argument('--dynamic-interval', type=int, default=None,
                   help='Interval for dynamic loss scaling')
  cli.add_argument('--init-scale', type=float, default=1,
                   help='Initial scale for ada loss')
  cli.add_argument('--loss-scale-method', type=str, default='approx_range',
                   help='Method for adaptive loss scaling')
  cli.add_argument('--scale-upper-bound', type=float, default=32800,
                   help='Hard upper bound for each scale factor')
  cli.add_argument('--accum-upper-bound', type=float, default=32800,
                   help='Accumulated upper bound for all scale factors')
  cli.add_argument('--update-per-n-iteration', type=int, default=100,
                   help='Update the loss scale value per n iteration')
  cli.add_argument('--snapshot-per-n-iteration', type=int, default=10000,
                   help='The frequency of taking snapshots')
  cli.add_argument('--n-uf', type=float, default=1e-3)
  cli.add_argument('--nosanity-check', action='store_true', default=False)
  cli.add_argument('--nouse-fp32-update', action='store_true', default=False)
  cli.add_argument('--profiling', action='store_true', default=False)
  cli.add_argument('--verbose', action='store_true', default=False,
                   help='Verbose output')
  args = cli.parse_args()

  # Setting data types (cuDNN is switched off for anything but float32)
  if args.dtype != 'float32':
    chainer.global_config.use_cudnn = 'never'
  chainer.global_config.dtype = dtypes[args.dtype]
  print('==> Setting the data type to {}'.format(args.dtype))

  # Initialize model; argparse ``choices`` restricts --model to these two
  ssd_cls = SSD300 if args.model == 'ssd300' else SSD512
  model = ssd_cls(n_fg_class=len(voc_bbox_label_names),
                  pretrained_model='imagenet')
  model.use_preset('evaluate')

  # Helpers handed to the adaptive loss scaling wrapper
  recorder = AdaLossRecorder(sample_per_n_iter=100)
  profiler = Profiler()
  sanity_checker = (None if args.nosanity_check
                    else SanityChecker(check_per_n_iter=100))

  # Update the model to support AdaLoss
  # TODO: refactorize
  ada_cfg = {
      'loss_scale_method': args.loss_scale_method,
      'scale_upper_bound': args.scale_upper_bound,
      'accum_upper_bound': args.accum_upper_bound,
      'update_per_n_iteration': args.update_per_n_iteration,
      'recorder': recorder,
      'profiler': profiler,
      'sanity_checker': sanity_checker,
      'n_uf_threshold': args.n_uf,
  }
  ada_transforms = [
      AdaLossTransformLinear(),
      AdaLossTransformConvolution2D(),
  ]
  scaled_model = AdaLossScaled(model,
                               init_scale=args.init_scale,
                               cfg=ada_cfg,
                               transforms=ada_transforms,
                               verbose=args.verbose)

  # Finalize the model
  train_chain = MultiboxTrainChain(scaled_model)
  if args.gpu >= 0:
    chainer.cuda.get_device_from_id(args.gpu).use()
    cp.random.seed(0)

    # NOTE: we have to transfer modules explicitly to GPU
    for module in (model.coder, model.extractor, model.multibox):
      module.to_gpu()

  # Prepare dataset
  if args.model_dir is not None:
    chainer.dataset.set_dataset_root(args.model_dir)
  voc_trainval = ConcatenatedDataset(
      VOCBboxDataset(year='2007', split='trainval'),
      VOCBboxDataset(year='2012', split='trainval'))
  train_transform = Transform(model.coder, model.insize, model.mean,
                              dtype=dtypes[args.dtype])
  train_set = TransformDataset(voc_trainval, train_transform)
  train_iter = chainer.iterators.MultithreadIterator(train_set,
                                                     args.batchsize)

  test_set = VOCBboxDataset(year='2007',
                            split='test',
                            use_difficult=True,
                            return_difficult=True)
  test_iter = chainer.iterators.SerialIterator(test_set,
                                               args.test_batchsize,
                                               repeat=False,
                                               shuffle=False)

  # initial lr is set to 1e-3 by ExponentialShift
  optimizer = chainer.optimizers.MomentumSGD()
  if args.dtype == 'mixed16':
    if not args.nouse_fp32_update:
      print('==> Using FP32 update for dtype=mixed16')
      optimizer.use_fp32_update()  # by default use fp32 update

    # HACK: support skipping update by existing loss scaling functionality
    use_dynamic = args.dynamic_interval is not None
    scaling_interval = args.dynamic_interval if use_dynamic else float('inf')
    optimizer.loss_scaling(interval=scaling_interval, scale=None)
    if not use_dynamic:
      optimizer._loss_scale_max = 1.0  # to prevent actual loss scaling

  optimizer.setup(train_chain)
  # parameters named 'b' get gradient scaling, all others weight decay
  for param in train_chain.params():
    rule_hook = (GradientScaling(2) if param.name == 'b'
                 else WeightDecay(0.0005))
    param.update_rule.add_hook(rule_hook)

  updater = training.updaters.StandardUpdater(train_iter,
                                              optimizer,
                                              device=args.gpu)
  trainer = training.Trainer(updater, (args.iteration, 'iteration'),
                             args.out)

  lr_shift = extensions.ExponentialShift('lr', 0.1, init=1e-3)
  trainer.extend(lr_shift,
                 trigger=triggers.ManualScheduleTrigger(args.step,
                                                        'iteration'))

  # evaluate at every lr step and at the final iteration
  voc_evaluator = DetectionVOCEvaluator(test_iter,
                                        model,
                                        use_07_metric=True,
                                        label_names=voc_bbox_label_names)
  trainer.extend(voc_evaluator,
                 trigger=triggers.ManualScheduleTrigger(
                     args.step + [args.iteration], 'iteration'))

  log_interval = (10, 'iteration')
  trainer.extend(extensions.LogReport(trigger=log_interval))
  trainer.extend(extensions.observe_lr(), trigger=log_interval)

  def current_loss_scale(trainer):
    # reads the optimizer's private loss-scale attribute
    return trainer.updater.get_optimizer('main')._loss_scale

  trainer.extend(extensions.observe_value('loss_scale', current_loss_scale),
                 trigger=log_interval)

  # 'loss_scale' is printed only when dynamic loss scaling was requested
  report_keys = ['epoch', 'iteration']
  if args.dynamic_interval is not None:
    report_keys.append('loss_scale')
  report_keys += [
      'lr', 'main/loss', 'main/loss/loc', 'main/loss/conf',
      'validation/main/map'
  ]
  trainer.extend(extensions.PrintReport(report_keys), trigger=log_interval)
  trainer.extend(extensions.ProgressBar(update_interval=10))

  trainer.extend(extensions.snapshot(),
                 trigger=triggers.ManualScheduleTrigger(
                     args.step + [args.iteration], 'iteration'))
  final_model_snapshot = extensions.snapshot_object(
      model, 'model_iter_{.updater.iteration}')
  trainer.extend(final_model_snapshot, trigger=(args.iteration, 'iteration'))

  if args.resume:
    serializers.load_npz(args.resume, trainer)

  monitor = AdaLossMonitor(sample_per_n_iter=100,
                           verbose=args.verbose,
                           includes=['Grad', 'Deconvolution'])
  recorder.trainer = trainer
  monitor.trainer = trainer

  # the monitor is active as a context manager for the whole run
  with ExitStack() as stack:
    stack.enter_context(monitor)
    trainer.run()

  # store recorded results
  out_dir = args.out
  recorder.export().to_csv(os.path.join(out_dir, 'loss_scale.csv'))
  profiler.export().to_csv(os.path.join(out_dir, 'profile.csv'))
  if sanity_checker:
    sanity_checker.export().to_csv(os.path.join(out_dir, 'sanity_check.csv'))
  monitor.export_history().to_csv(os.path.join(out_dir, 'grad_stats.csv'))
Beispiel #22
0
def train(args):
    '''Run training

    Builds an ``E2E`` model wrapped in ``Loss``, writes the pickled model
    configuration to ``args.outdir``, and runs a chainer ``Trainer`` that
    drives a PyTorch optimizer over Kaldi features.

    Args:
        args (namespace): The program arguments.  This function reads
            ``seed``, ``debugmode``, ``train_label``, ``valid_label``,
            ``train_feat``, ``valid_feat``, ``mtlalpha``, ``outdir``,
            ``gpu``, ``opt``, ``eps``, ``eps_decay``, ``criterion``,
            ``batch_size``, ``maxlen_in``, ``maxlen_out``, ``minibatches``,
            ``grad_clip``, ``epochs`` and ``resume``.

    Raises:
        NotImplementedError: If ``args.opt`` is neither ``'adadelta'`` nor
            ``'adam'``, or if ``args.resume`` is set (resuming is not
            implemented).
    '''
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        # The flag being switched off is chainer's, so the message says so.
        chainer.config.type_check = False
        logging.info('chainer type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_label, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['idim'])
    odim = int(valid_json[utts[0]]['odim'])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = Loss(e2e, args.mtlalpha)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.conf'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # TODO(watanabe) use others than pickle, possibly json, and save as a text
        pickle.dump((idim, odim, args), f)
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    # The reporter is taken before the model is possibly wrapped in
    # DataParallel below.
    reporter = model.reporter
    # args.gpu is used as a GPU count here, not as a device ID.
    ngpu = int(args.gpu)
    if ngpu == 1:
        gpu_id = 0
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()
    elif ngpu > 1:
        gpu_id = range(ngpu)
        logging.info('gpu id: ' + str(gpu_id))
        model = torch.nn.DataParallel(model, device_ids=gpu_id)
        model.cuda()
    else:
        # CPU run: previously gpu_id was left unbound and the updater
        # construction below failed with UnboundLocalError.
        # NOTE(review): -1 follows the "negative value indicates CPU"
        # convention used elsewhere in this file -- confirm that the
        # updater/evaluator accept it.
        gpu_id = -1

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise NotImplementedError('unknown optimizer: ' + str(args.opt))

    # FIXME: TOO DIRTY HACK
    # Attach the attributes the chainer trainer looks up on an optimizer.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data (the validation labels were already loaded above)
    with open(args.train_label, 'rb') as f:
        train_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = chainer.iterators.SerialIterator(train, 1)
    valid_iter = chainer.iterators.SerialIterator(
        valid, 1, repeat=False, shuffle=False)

    # prepare Kaldi reader
    train_reader = lazy_io.read_dict_scp(args.train_feat)
    valid_reader = lazy_io.read_dict_scp(args.valid_feat)

    # Set up a trainer
    updater = PytorchSeqUpdaterKaldi(
        model, args.grad_clip, train_iter, optimizer, train_reader, gpu_id)
    trainer = training.Trainer(
        updater, (args.epochs, 'epoch'), out=args.outdir)

    # Resume from a snapshot (not implemented)
    if args.resume:
        raise NotImplementedError

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(PytorchSeqEvaluaterKaldi(
        model, valid_iter, reporter, valid_reader, device=gpu_id))

    # Take a snapshot for each specified epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss',
                                          'main/loss_ctc', 'validation/main/loss_ctc',
                                          'main/loss_att', 'validation/main/loss_att'],
                                         'epoch', file_name='loss.png'))
    trainer.extend(extensions.PlotReport(['main/acc', 'validation/main/acc'],
                                         'epoch', file_name='acc.png'))

    # Save best models
    def torch_save(path, _):
        # Writes both the state dict and the whole pickled model.
        torch.save(model.state_dict(), path)
        torch.save(model, path + ".pkl")

    trainer.extend(extensions.snapshot_object(model, 'model.loss.best', savefun=torch_save),
                   trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(extensions.snapshot_object(model, 'model.acc.best', savefun=torch_save),
                   trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer
    def torch_load(path, obj):
        model.load_state_dict(torch.load(path))
        return obj
    if args.opt == 'adadelta':
        # Each extension gets its own CompareValueTrigger instance.
        if args.criterion == 'acc':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc',
                               lambda best_value, current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss',
                               lambda best_value, current_value: best_value < current_value))

    # Write a log of evaluation statistics every 100 iterations
    trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
    report_keys = ['epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
                   'validation/main/loss', 'validation/main/loss_ctc', 'validation/main/loss_att',
                   'main/acc', 'validation/main/acc', 'elapsed_time']
    if args.opt == 'adadelta':
        # Also report the current Adadelta epsilon of the first param group.
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').param_groups[0]["eps"]),
            trigger=(100, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(
        report_keys), trigger=(100, 'iteration'))

    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()
Beispiel #23
0
def main():
    """Parse the command-line options and train an E2E model with PyTorch.

    The function configures logging and random seeds, reads the label JSON
    files and Kaldi feature lists named on the command line, writes the
    pickled model configuration to ``--outdir``, and runs a chainer
    ``Trainer`` around a PyTorch optimizer.

    Raises:
        NotImplementedError: If ``--resume`` is given (resuming is not
            implemented).
    """
    parser = argparse.ArgumentParser()
    # general configuration
    parser.add_argument('--gpu',
                        '-g',
                        default='-1',
                        type=str,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--outdir',
                        type=str,
                        required=True,
                        help='Output directory')
    parser.add_argument('--debugmode', default=1, type=int, help='Debugmode')
    parser.add_argument('--dict', required=True, help='Dictionary')
    parser.add_argument('--seed', default=1, type=int, help='Random seed')
    parser.add_argument('--debugdir',
                        type=str,
                        help='Output directory for debugging')
    parser.add_argument('--resume',
                        '-r',
                        default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--minibatches',
                        '-N',
                        type=int,
                        default='-1',
                        help='Process only N minibatches (for debug)')
    parser.add_argument('--verbose',
                        '-V',
                        default=0,
                        type=int,
                        help='Verbose option')
    # task related
    parser.add_argument('--train-feat',
                        type=str,
                        required=True,
                        help='Filename of train feature data (Kaldi scp)')
    parser.add_argument('--valid-feat',
                        type=str,
                        required=True,
                        help='Filename of validation feature data (Kaldi scp)')
    parser.add_argument('--train-label',
                        type=str,
                        required=True,
                        help='Filename of train label data (json)')
    parser.add_argument('--valid-label',
                        type=str,
                        required=True,
                        help='Filename of validation label data (json)')
    # network architecture
    # encoder
    parser.add_argument('--etype',
                        default='blstmp',
                        type=str,
                        choices=['blstm', 'blstmp', 'vggblstmp', 'vggblstm'],
                        help='Type of encoder network architecture')
    parser.add_argument('--elayers',
                        default=4,
                        type=int,
                        help='Number of encoder layers')
    parser.add_argument('--eunits',
                        '-u',
                        default=300,
                        type=int,
                        help='Number of encoder hidden units')
    parser.add_argument('--eprojs',
                        default=320,
                        type=int,
                        help='Number of encoder projection units')
    # argparse applies ``type`` only to string defaults, so the default must
    # itself be a string for the option to always yield a str.
    parser.add_argument(
        '--subsample',
        default='1',
        type=str,
        help=
        'Subsample input frames x_y_z means subsample every x frame at 1st layer, '
        'every y frame at 2nd layer etc.')
    # attention
    parser.add_argument('--atype',
                        default='dot',
                        type=str,
                        choices=['dot', 'location', 'noatt'],
                        help='Type of attention architecture')
    parser.add_argument('--adim',
                        default=320,
                        type=int,
                        help='Number of attention transformation dimensions')
    parser.add_argument('--aconv-chans',
                        default=-1,
                        type=int,
                        help='Number of attention convolution channels \
                        (negative value indicates no location-aware attention)'
                        )
    parser.add_argument('--aconv-filts',
                        default=100,
                        type=int,
                        help='Number of attention convolution filters \
                        (negative value indicates no location-aware attention)'
                        )
    # decoder
    parser.add_argument('--dtype',
                        default='lstm',
                        type=str,
                        choices=['lstm'],
                        help='Type of decoder network architecture')
    parser.add_argument('--dlayers',
                        default=1,
                        type=int,
                        help='Number of decoder layers')
    parser.add_argument('--dunits',
                        default=320,
                        type=int,
                        help='Number of decoder hidden units')
    parser.add_argument(
        '--mtlalpha',
        default=0.5,
        type=float,
        help=
        'Multitask learning coefficient, alpha: alpha*ctc_loss + (1-alpha)*att_loss '
    )
    # model (parameter) related
    parser.add_argument('--dropout-rate',
                        default=0.0,
                        type=float,
                        help='Dropout rate')
    # minibatch related
    parser.add_argument('--batch-size',
                        '-b',
                        default=50,
                        type=int,
                        help='Batch size')
    parser.add_argument(
        '--maxlen-in',
        default=800,
        type=int,
        metavar='ML',
        help='Batch size is reduced if the input sequence length > ML')
    parser.add_argument(
        '--maxlen-out',
        default=150,
        type=int,
        metavar='ML',
        help='Batch size is reduced if the output sequence length > ML')
    # optimization related
    parser.add_argument('--opt',
                        default='adadelta',
                        type=str,
                        choices=['adadelta', 'adam'],
                        help='Optimizer')
    parser.add_argument('--eps',
                        default=1e-8,
                        type=float,
                        help='Epsilon constant for optimizer')
    parser.add_argument('--eps-decay',
                        default=0.01,
                        type=float,
                        help='Decaying ratio of epsilon')
    parser.add_argument('--criterion',
                        default='acc',
                        type=str,
                        choices=['loss', 'acc'],
                        help='Criterion to perform epsilon decay')
    parser.add_argument('--threshold',
                        default=1e-4,
                        type=float,
                        help='Threshold to stop iteration')
    parser.add_argument('--epochs',
                        '-e',
                        default=30,
                        type=int,
                        help='Number of maximum epochs')
    parser.add_argument('--grad-clip',
                        default=5,
                        type=float,
                        help='Gradient norm threshold to clip')
    args = parser.parse_args()

    # logging info
    log_format = (
        '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning('Skip DEBUG/INFO messages')

    # display PYTHONPATH (it may be unset, which must not abort the run)
    logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))

    # display chainer version
    logging.info('chainer version = ' + chainer.__version__)

    # seed setting (chainer seed may not need it)
    nseed = args.seed
    random.seed(nseed)
    np.random.seed(nseed)
    torch.manual_seed(nseed)
    os.environ['CHAINER_SEED'] = str(nseed)
    logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('chainer type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        chainer.config.cudnn_deterministic = False
        torch.backends.cudnn.deterministic = False
        logging.info('chainer cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True
        chainer.config.cudnn_deterministic = True
    # load dictionary for debug log
    if args.dict is not None:
        with open(args.dict, 'rb') as f:
            dictionary = f.readlines()
        # Keep the first space-separated field of every dictionary line.
        char_list = [
            entry.decode('utf-8').split(' ')[0] for entry in dictionary
        ]
        char_list.insert(0, '<blank>')
        char_list.append('<eos>')
        args.char_list = char_list
    else:
        args.char_list = None

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning('cuda is not available')
    if not chainer.cuda.cudnn_enabled:
        logging.warning('cudnn is not available')

    # get input and output dimension info
    with open(args.valid_label, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['idim'])
    odim = int(valid_json[utts[0]]['odim'])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = Loss(e2e, args.mtlalpha)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.conf'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        # TODO(watanabe) use others than pickle, possibly json, and save as a text
        pickle.dump((idim, odim, args), f)
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    # Set gpu
    gpu_id = int(args.gpu)
    logging.info('gpu id: ' + str(gpu_id))
    if gpu_id >= 0:
        # Make a specified GPU current
        model.cuda(gpu_id)  # Copy the model to the GPU

    # Setup an optimizer (``--opt`` is restricted to these two choices)
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())

    # FIXME: TOO DIRTY HACK
    # Attach the attributes the chainer trainer looks up on an optimizer.
    setattr(optimizer, "target", model.reporter)
    setattr(optimizer, "serialize", lambda s: model.reporter.serialize(s))

    # read json data (the validation labels were already loaded above)
    with open(args.train_label, 'rb') as f:
        train_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    valid = make_batchset(valid_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = chainer.iterators.SerialIterator(train, 1)
    valid_iter = chainer.iterators.SerialIterator(valid,
                                                  1,
                                                  repeat=False,
                                                  shuffle=False)

    # prepare Kaldi reader
    train_reader = lazy_io.read_dict_scp(args.train_feat)
    valid_reader = lazy_io.read_dict_scp(args.valid_feat)

    # Set up a trainer
    updater = SeqUpdaterKaldi(model, args.grad_clip, train_iter, optimizer,
                              train_reader, gpu_id)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot (not implemented)
    if args.resume:
        raise NotImplementedError

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        SeqEvaluaterKaldi(model,
                          valid_iter,
                          model.reporter,
                          valid_reader,
                          device=gpu_id))

    # Take a snapshot for each specified epoch
    trainer.extend(extensions.snapshot(), trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    def torch_save(path, _):
        # Writes both the state dict and the whole pickled model.
        torch.save(model.state_dict(), path)
        torch.save(model, path + ".pkl")

    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(
        extensions.snapshot_object(model, 'model.acc.best',
                                   savefun=torch_save),
        trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # epsilon decay in the optimizer
    def torch_load(path, obj):
        model.load_state_dict(torch.load(path))
        return obj

    if args.opt == 'adadelta':
        # Each extension gets its own CompareValueTrigger instance.
        if args.criterion == 'acc':
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + '/model.acc.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + '/model.loss.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))

    # Write a log of evaluation statistics every 100 iterations
    trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # Also report the current Adadelta epsilon of the first param group.
        trainer.extend(
            extensions.observe_value(
                'eps',
                lambda trainer: trainer.updater.get_optimizer('main').
                param_groups[0]["eps"]),
            trigger=(100, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(100, 'iteration'))

    trainer.extend(extensions.ProgressBar())

    # Run the training
    trainer.run()
Beispiel #24
0
def observe_hyperparam(name):
    """Create an extension that observes one attribute of the main optimizer.

    Args:
        name (str): Attribute name looked up on the optimizer registered
            as ``'main'`` in the trainer's updater.  The same name is used
            as the observation key.

    Returns:
        The extension built by ``extensions.observe_value``.
    """
    def _read_from_optimizer(trainer):
        main_optimizer = trainer.updater.get_optimizer('main')
        return getattr(main_optimizer, name)

    hyperparam_extension = extensions.observe_value(name, _read_from_optimizer)
    return hyperparam_extension
Beispiel #25
0
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][-1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][-1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args)
    elif args.asr_init is not None:
        model, _ = load_trained_model(args.asr_init)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, ASRInterface)

    subsampling_factor = model.subsample[0]

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer,
                             rnnlm_args.unit))
        torch.load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.info('batch size is automatically increased (%d -> %d)' %
                         (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    logging.info(device)
    logging.info(dtype)
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     weight_decay=args.weight_decay)
    elif args.opt == 'noam':
        from espnet.nets.pytorch_backend.rnn.optimizer import get_std_opt
        optimizer = get_std_opt(model, args.adim,
                                args.transformer_warmup_steps,
                                args.transformer_lr)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == 'noam':
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(subsampling_factor=subsampling_factor,
                                dtype=dtype)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)
    valid = make_batchset(valid_json,
                          args.batch_size,
                          args.maxlen_in,
                          args.maxlen_out,
                          args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0,
                          oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr',
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    if args.n_iter_processes > 0:
        train_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingMultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20)
    else:
        train_iter = ToggleableShufflingSerialIterator(
            TransformDataset(train, load_tr),
            batch_size=1,
            shuffle=not use_sortagrad)
        valid_iter = ToggleableShufflingSerialIterator(TransformDataset(
            valid, load_cv),
                                                       batch_size=1,
                                                       repeat=False,
                                                       shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model,
                            args.grad_clip,
                            train_iter,
                            optimizer,
                            converter,
                            device,
                            args.ngpu,
                            args.grad_noise,
                            args.accum_grad,
                            use_apex=use_apex)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     'epoch'))

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device,
                        args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(att_vis_fn,
                                  data,
                                  args.outdir + "/att_ws",
                                  converter=converter,
                                  transform=load_cv,
                                  device=device)
        trainer.extend(att_reporter, trigger=(1, 'epoch'))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))
    trainer.extend(
        extensions.PlotReport(['main/cer_ctc', 'validation/main/cer_ctc'],
                              'epoch',
                              file_name='cer.png'))

    # Save best models
    trainer.extend(
        snapshot_object(model, 'model.loss.best'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    if mtl_mode != 'ctc':
        trainer.extend(
            snapshot_object(model, 'model.acc.best'),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.acc.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/acc', lambda best_value,
                               current_value: best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(restore_snapshot(model,
                                            args.outdir + '/model.loss.best',
                                            load_fn=torch_load),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))
            trainer.extend(adadelta_eps_decay(args.eps_decay),
                           trigger=CompareValueTrigger(
                               'validation/main/loss', lambda best_value,
                               current_value: best_value < current_value))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'main/cer_ctc', 'validation/main/cer_ctc', 'elapsed_time'
    ]
    if args.opt == 'adadelta':
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(args.report_interval_iters, 'iteration'))
        report_keys.append('eps')
    if args.report_cer:
        report_keys.append('validation/main/cer')
    if args.report_wer:
        report_keys.append('validation/main/wer')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(args.report_interval_iters, 'iteration'))

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                                         att_reporter),
                       trigger=(args.report_interval_iters, "iteration"))
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #26
0
def train(args):
    """Train the pix2pixHD discriminator/generator pair with a Chainer trainer.

    Loads the dataset configuration from the JSON file at ``args.config``,
    builds the models, one Adam optimizer per model, the iterators and the
    updater, registers reporting / snapshot / learning-rate extensions and
    finally runs the trainer.

    Args:
        args: Parsed command-line namespace.  Attributes read directly here:
            ``config``, ``device`` (sequence of device ids; a negative first
            entry selects the CPU, more than one entry enables chainermn),
            ``learning_rate``, ``no_one_hot``, ``batchsize``, ``epochs``,
            ``output``, ``num_vis`` and ``resume``.  The namespace is also
            forwarded to ``setup_models``.
    """
    with open(args.config, "r") as cfg:
        data_config = json.load(cfg)

    size = data_config["train"]["kwargs"]["size"]

    # set up GPUs if necessary (-1 means CPU, no communicator means one process)
    device = -1
    comm = None

    if args.device[0] < 0:
        print("training on CPU (not recommended!)")
    elif len(args.device) > 1:
        # chainermn is only required, and therefore only imported, for
        # multi-GPU runs
        import chainermn
        comm = chainermn.create_communicator()
        device = args.device[comm.intra_rank]

        print("using multi-gpu training with GPU {}".format(device))
    else:
        device = args.device[0]
        print("using single gpu training with GPU {}".format(device))

    disc, gen = setup_models(args, size, device)

    def setup_optimizer(opt, model, comm=None):
        # wrap the optimizer for multi-node training when a communicator exists
        if comm is not None:
            opt = chainermn.create_multi_node_optimizer(opt, comm)
        opt.setup(model)
        return opt

    opt_disc = setup_optimizer(
        chainer.optimizers.Adam(alpha=args.learning_rate), disc, comm)
    opt_gen = setup_optimizer(
        chainer.optimizers.Adam(alpha=args.learning_rate), gen, comm)

    # only the single process / rank 0 builds the datasets; the other ranks
    # receive their share through scatter_dataset below
    if comm is None or comm.rank == 0:
        dataset_cls = getattr(pix2pixHD, data_config["class_name"])
        train_d = dataset_cls(*data_config["train"]["args"],
                              **data_config["train"]["kwargs"],
                              one_hot=args.no_one_hot)
        test_d = dataset_cls(*data_config["test"]["args"],
                             **data_config["test"]["kwargs"],
                             one_hot=args.no_one_hot,
                             random_flip=False)
    else:
        train_d, test_d = None, None

    if comm is not None:
        train_d = chainermn.scatter_dataset(train_d, comm)
        test_d = chainermn.scatter_dataset(test_d, comm)
        multiprocessing.set_start_method('forkserver')

    train_iter = chainer.iterators.MultiprocessIterator(train_d,
                                                        args.batchsize,
                                                        n_processes=2)
    test_iter = chainer.iterators.SerialIterator(test_d,
                                                 args.batchsize,
                                                 shuffle=False)

    iterators = {'main': train_iter, 'test': test_iter}
    optimizers = {'discriminator': opt_disc, 'generator': opt_gen}

    updater = Pix2pixHDUpdater(iterators, optimizers, device=device)

    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.output)

    if comm is None or comm.rank == 0:
        vis_trigger = (100, "iteration")
        report_trigger = (10, "iteration")
        # ``args.epochs // 10`` is 0 for runs shorter than 10 epochs, which is
        # not a usable trigger period; snapshot at least once per epoch then.
        snapshot_trigger = (max(1, args.epochs // 10), "epoch")

        if comm is None:
            vis_source = train_d
        else:
            # hack: after scatter_dataset the visualizer is reached through
            # the wrapped dataset's private ``_dataset`` attribute
            vis_source = train_d._dataset
        trainer.extend(vis_source.visualizer(n=args.num_vis,
                                             one_hot=args.no_one_hot),
                       trigger=vis_trigger)

        trainer.extend(extensions.LogReport(trigger=report_trigger))
        trainer.extend(extensions.PrintReport([
            'epoch', 'iteration', 'Dloss_real', 'Dloss_fake', 'Gloss',
            "feat_loss", "lr"
        ]),
                       trigger=report_trigger)
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.extend(
            extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'),
            trigger=snapshot_trigger)
        trainer.extend(extensions.snapshot_object(
            gen, 'generator_model_epoch_{.updater.epoch}'),
                       trigger=snapshot_trigger)

    # decay the learning rate from halfway through training, for both
    # optimizers (discriminator first, as before)
    for opt in (opt_disc, opt_gen):
        trainer.extend(extensions.LinearShift("alpha",
                                              value_range=(args.learning_rate,
                                                           0.0),
                                              time_range=(args.epochs // 2,
                                                          args.epochs),
                                              optimizer=opt),
                       trigger=(1, "epoch"))

    # report the discriminator optimizer's current learning rate as "lr"
    trainer.extend(extensions.observe_value(
        "lr",
        lambda trainer: trainer.updater.get_optimizer("discriminator").lr),
                   trigger=(10, "iteration"))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()
Beispiel #27
0
def train(args):
    """Run training of the E2E speech model with a Chainer trainer loop.

    Builds the model and optimizer from ``args``, writes the model
    configuration to ``<outdir>/model.json``, creates the training and
    validation iterators, registers evaluation / plotting / snapshot /
    reporting extensions and runs the trainer.

    Args:
        args: Parsed command-line namespace.  Attributes read directly here:
            ``seed``, ``debugmode``, ``valid_json``, ``train_json``,
            ``mtlalpha``, ``outdir``, ``ngpu``, ``batch_size`` (multiplied by
            ``ngpu`` in place when ``ngpu > 1``), ``opt``, ``eps``,
            ``maxlen_in``, ``maxlen_out``, ``minibatches``, ``grad_clip``,
            ``epochs``, ``resume``, ``num_save_attention``, ``criterion`` and
            ``eps_decay``.  The namespace is also forwarded to ``E2E``.

    Raises:
        NotImplementedError: If ``args.opt`` is neither ``'adadelta'`` nor
            ``'adam'``.
    """
    # seed setting
    torch.manual_seed(args.seed)

    # debug mode setting
    # 0 would be fastest, but 1 seems to be reasonable
    # by considering reproducibility
    # remove type check
    if args.debugmode < 2:
        chainer.config.type_check = False
        logging.info('torch type check is disabled')
    # use deterministic computation or not
    if args.debugmode < 1:
        torch.backends.cudnn.deterministic = False
        logging.info('torch cudnn deterministic is disabled')
    else:
        torch.backends.cudnn.deterministic = True

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning('cuda is not available')

    # get input and output dimension info from the first validation utterance
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]['input'][0]['shape'][1])
    odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
    logging.info('#input dims : ' + str(idim))
    logging.info('#output dims: ' + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = 'ctc'
        logging.info('Pure CTC mode')
    elif args.mtlalpha == 0.0:
        mtl_mode = 'att'
        logging.info('Pure attention mode')
    else:
        mtl_mode = 'mtl'
        logging.info('Multitask learning mode')

    # specify model architecture
    e2e = E2E(idim, odim, args)
    model = Loss(e2e, args.mtlalpha)

    # write model config
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + '/model.json'
    with open(model_conf, 'wb') as f:
        logging.info('writing a model config file to ' + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))
    for key in sorted(vars(args).keys()):
        logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        logging.info('batch size is automatically increased (%d -> %d)' %
                     (args.batch_size, args.batch_size * args.ngpu))
        args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # Setup an optimizer
    if args.opt == 'adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model.parameters())
    else:
        # fail with a clear message instead of an unbound ``optimizer`` below
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # FIXME: TOO DIRTY HACK
    # attach the attributes the chainer trainer machinery looks up on an
    # optimizer to the torch optimizer
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(e2e.subsample[0])

    # read json data (valid_json was already loaded above and is reused)
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']

    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    valid = make_batchset(valid_json, args.batch_size, args.maxlen_in,
                          args.maxlen_out, args.minibatches)
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    train_iter = chainer.iterators.MultiprocessIterator(
        TransformDataset(train, converter.transform),
        batch_size=1,
        n_processes=1,
        n_prefetch=8)
    valid_iter = chainer.iterators.SerialIterator(
        TransformDataset(valid, converter.transform),
        batch_size=1,
        repeat=False,
        shuffle=False)

    # Set up a trainer
    updater = CustomUpdater(model, args.grad_clip, train_iter, optimizer,
                            converter, device, args.ngpu)
    trainer = training.Trainer(updater, (args.epochs, 'epoch'),
                               out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info('resumed from %s' % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, converter, device))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        # first ``num_save_attention`` utterances, largest input shape[1] first
        data = sorted(list(valid_json.items())[:args.num_save_attention],
                      key=lambda x: int(x[1]['input'][0]['shape'][1]),
                      reverse=True)
        # DataParallel wraps the real model in ``.module``
        if hasattr(model, "module"):
            att_vis_fn = model.module.predictor.calculate_all_attentions
        else:
            att_vis_fn = model.predictor.calculate_all_attentions
        trainer.extend(PlotAttentionReport(att_vis_fn,
                                           data,
                                           args.outdir + "/att_ws",
                                           converter=converter,
                                           device=device),
                       trigger=(1, 'epoch'))

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport([
            'main/loss', 'validation/main/loss', 'main/loss_ctc',
            'validation/main/loss_ctc', 'main/loss_att',
            'validation/main/loss_att'
        ],
                              'epoch',
                              file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(['main/acc', 'validation/main/acc'],
                              'epoch',
                              file_name='acc.png'))

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model,
                                   'model.loss.best',
                                   savefun=torch_save),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    # compare strings by value; ``is not`` on a literal tests object identity
    if mtl_mode != 'ctc':
        trainer.extend(
            extensions.snapshot_object(model,
                                       'model.acc.best',
                                       savefun=torch_save),
            trigger=training.triggers.MaxValueTrigger('validation/main/acc'))

    # save snapshot which contains model and optimizer states
    trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))

    # epsilon decay in the optimizer: when the monitored value gets worse than
    # the best one so far, restore the best model and decay eps.  Each
    # extension gets its own trigger instance.
    if args.opt == 'adadelta':
        if args.criterion == 'acc' and mtl_mode != 'ctc':
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + '/model.acc.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/acc',
                    lambda best_value, current_value:
                    best_value > current_value))
        elif args.criterion == 'loss':
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + '/model.loss.best',
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    'validation/main/loss',
                    lambda best_value, current_value:
                    best_value < current_value))

    # Write a log of evaluation statistics every REPORT_INTERVAL iterations
    trainer.extend(extensions.LogReport(trigger=(REPORT_INTERVAL,
                                                 'iteration')))
    report_keys = [
        'epoch', 'iteration', 'main/loss', 'main/loss_ctc', 'main/loss_att',
        'validation/main/loss', 'validation/main/loss_ctc',
        'validation/main/loss_att', 'main/acc', 'validation/main/acc',
        'elapsed_time'
    ]
    if args.opt == 'adadelta':
        # expose the current eps of the first parameter group in the report
        trainer.extend(extensions.observe_value(
            'eps', lambda trainer: trainer.updater.get_optimizer('main').
            param_groups[0]["eps"]),
                       trigger=(REPORT_INTERVAL, 'iteration'))
        report_keys.append('eps')
    trainer.extend(extensions.PrintReport(report_keys),
                   trigger=(REPORT_INTERVAL, 'iteration'))

    trainer.extend(extensions.ProgressBar(update_interval=REPORT_INTERVAL))

    # Run the training
    trainer.run()
Beispiel #28
0
def main(args):
    """Train a DenseNet classifier on CIFAR-10 / CIFAR-100 with Chainer.

    Args:
        args: Parsed command-line namespace.  Attributes read here:
            ``depth``, ``block``, ``dataset`` (``'cifar10'`` or
            ``'cifar100'``), ``batchsize``, ``growth_rate``, ``drop_ratio``,
            ``init_model``, ``lr``, ``gpus`` (sequence of device ids, the
            first one is the main device), ``weight_decay``, ``epoch``,
            ``dir``, ``lr_decay_ratio`` and ``lr_decay_freq``.

    Raises:
        NotImplementedError: If ``args.dataset`` is ``'SVHN'``.
        ValueError: If ``args.dataset`` is any other unknown name.
    """
    assert ((args.depth - args.block - 1) % args.block == 0)
    # integer division: the assert above guarantees an exact result, and a
    # float layer count (true division on Python 3) is not a valid count
    n_layer = (args.depth - args.block - 1) // args.block
    if args.dataset == 'cifar10':
        train, test = cifar.get_cifar10()
        n_class = 10
    elif args.dataset == 'cifar100':
        train, test = cifar.get_cifar100()
        n_class = 100
    elif args.dataset == 'SVHN':
        raise NotImplementedError()
    else:
        # fail early instead of hitting unbound ``train`` / ``test`` below
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    # mean image over the training set, accumulated one sample at a time
    mean = numpy.zeros((3, 32, 32), dtype=numpy.float32)
    for image, _ in train:
        mean += image / len(train)

    train = PreprocessedDataset(train, mean, random=True)
    test = PreprocessedDataset(test, mean)

    train_iter = chainer.iterators.MultiprocessIterator(train, args.batchsize)
    test_iter = chainer.iterators.MultiprocessIterator(test,
                                                       args.batchsize,
                                                       repeat=False,
                                                       shuffle=False)

    model = chainer.links.Classifier(
        DenseNet(n_layer, args.growth_rate, n_class, args.drop_ratio, 16,
                 args.block))
    if args.init_model:
        serializers.load_npz(args.init_model, model)

    # the learning rate is scaled down by the number of devices
    optimizer = chainer.optimizers.MomentumSGD(lr=args.lr / len(args.gpus),
                                               momentum=0.9)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))

    # register every additional device; ``> 1`` (not ``> 2``) so that a
    # two-GPU run really uses both GPUs, matching the lr scaling above
    devices = {'main': args.gpus[0]}
    if len(args.gpus) > 1:
        for gid in args.gpus[1:]:
            devices['gpu%d' % gid] = gid
    updater = training.ParallelUpdater(train_iter, optimizer, devices=devices)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.dir)

    val_interval = (1, 'epoch')
    log_interval = (1, 'epoch')

    # separate copy of the classifier used for evaluation
    eval_model = model.copy()
    eval_model.train = False

    trainer.extend(extensions.Evaluator(test_iter,
                                        eval_model,
                                        device=args.gpus[0]),
                   trigger=val_interval)
    trainer.extend(extensions.ExponentialShift('lr', args.lr_decay_ratio),
                   trigger=(args.lr_decay_freq, 'epoch'))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot_object(model,
                                              'epoch_{.updater.epoch}.model'),
                   trigger=val_interval)
    trainer.extend(extensions.snapshot_object(optimizer,
                                              'epoch_{.updater.epoch}.state'),
                   trigger=val_interval)
    trainer.extend(extensions.LogReport(trigger=log_interval))
    trainer.extend(extensions.observe_lr(), trigger=log_interval)
    # report wall-clock seconds since this point as 'time'
    start_time = time.time()
    trainer.extend(extensions.observe_value(
        'time', lambda _: time.time() - start_time),
                   trigger=log_interval)
    trainer.extend(extensions.PrintReport([
        'time',
        'epoch',
        'iteration',
        'main/loss',
        'validation/main/loss',
        'main/accuracy',
        'validation/main/accuracy',
        'lr',
    ]),
                   trigger=log_interval)
    # observe_value is used here only to call create_fig every two epochs
    trainer.extend(extensions.observe_value('graph',
                                            lambda _: create_fig(args.dir)),
                   trigger=(2, 'epoch'))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()
Beispiel #29
0
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, MTInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad) *
            100.0 / sum(p.numel() for p in model.parameters()),
        ))

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter()

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0:
        # NOTE: sort it by output lengths
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(["main/loss", "validation/main/loss"],
                              "epoch",
                              file_name="loss.png"))
    trainer.extend(
        extensions.PlotReport(["main/acc", "validation/main/acc"],
                              "epoch",
                              file_name="acc.png"))
    trainer.extend(
        extensions.PlotReport(["main/ppl", "validation/main/ppl"],
                              "epoch",
                              file_name="ppl.png"))
    trainer.extend(
        extensions.PlotReport(["main/bleu", "validation/main/bleu"],
                              "epoch",
                              file_name="bleu.png"))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "validation/main/loss",
        "main/acc",
        "validation/main/acc",
        "main/ppl",
        "validation/main/ppl",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["lr"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                              att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #30
0
def train(args):
    """Train with the given args.

    Builds the model, the optimizer, the training/validation iterators and a
    trainer, registers the trainer extensions (evaluation, attention plots,
    loss/accuracy/CER plots, snapshots, optimizer decay, logging) and finally
    runs the training loop.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.opt`` is not one of ``adadelta``,
            ``adam``, ``noam`` or ``rmsprop``, or if more than one GPU is
            requested together with more than one encoder.
        ImportError: If ``args.train_dtype`` is an apex opt level
            (``O0``-``O3``) but ``apex`` cannot be imported.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for i in range(args.num_encs):
        logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(
            idim_list[0] if args.num_encs == 1 else idim_list, odim, args
        )
    assert isinstance(model, ASRInterface)

    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
        )
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # write model config
    # exist_ok avoids the check-then-create race when several jobs share outdir
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            # TODO(ruizhili): implement data parallel for multi-encoder setup.
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup."
            )

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt levels (O0-O3) and anything else start from float32
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    elif args.opt == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=0.0008, alpha=0.95)
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise
        if args.opt == "noam":
            # the noam optimizer wraps the real torch optimizer in ``.optimizer``
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True

        from espnet.nets.pytorch_backend.ctc import CTC

        amp.register_float_function(CTC, "loss_fn")
        amp.init()
        logging.warning("register ctc as float function")
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
    else:
        converter = CustomConverterMulEnc(
            [i[0] for i in model.subsample_list], dtype=dtype
        )

    # read json data
    # (valid_json was already loaded above from the same path; reuse it)
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )
    # hack to make batchsize argument as 1
    # actual batchsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        # sortagrad == -1 means "keep sorted order for all epochs"
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )

    # Save attention weight each epoch
    if (
        args.num_save_attention > 0
        and args.mtlalpha != 1.0
        and "transformer" in args.model_module
    ):
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if args.num_encs > 1:
        # per-encoder CTC loss / CER keys, only reported in multi-encoder mode
        report_keys_loss_ctc = [
            "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + [
            "validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
        ]
        report_keys_cer_ctc = [
            "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + [
            "validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
        ]
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ]
            + ([] if args.num_encs == 1 else report_keys_loss_ctc),
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"]
            # the CER plot must carry the per-encoder CER keys, not the loss keys
            + ([] if args.num_encs == 1 else report_keys_cer_ctc),
            "epoch",
            file_name="cer.png",
        )
    )

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode != "ctc":
        trainer.extend(
            snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    # When the validation metric gets worse, roll the model back to the best
    # snapshot first and then decay eps.
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # lr decay in rmsprop
    if args.opt == "rmsprop":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                rmsprop_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                rmsprop_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "main/cer_ctc",
        "validation/main/cer_ctc",
        "elapsed_time",
    ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc)
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    if args.opt == "rmsprop":
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "lr"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_cer:
        report_keys.append("validation/main/cer")
    if args.report_wer:
        report_keys.append("validation/main/wer")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Beispiel #31
0
def train(model, train, test=None, params=None):
    """Train ``model`` with Adam and optionally evaluate it on ``test``.

    Args:
        model: Chainer link handed to the optimizer, updater and evaluator.
        train: Training dataset wrapped in a shuffling ``SerialIterator``.
        test: Optional evaluation dataset; when given, an ``Evaluator``
            extension is registered on the trainer.
        params: Optional mapping of overrides merged on top of the
            module-level ``default_params``. Keys read here are ``"gpus"``,
            ``"batch_size"``, ``"nb_epoch"`` and ``"path_results"``.

    Raises:
        SystemExit: With code ``-1`` if more than one GPU is requested.
    """
    # Work on a copy: updating ``default_params`` in place would leak the
    # overrides of one call into every later call.
    params_local = dict(default_params)
    if params:
        params_local.update(params)
    gpus = params_local["gpus"]
    print("GPUs ", gpus)

    if len(gpus) > 1:
        print("multiple gpus not implemented yet")
        raise SystemExit(-1)

    use_gpu = len(gpus) == 1
    if use_gpu:
        chainer.cuda.get_device(gpus[0]).use()
        model.to_gpu()

    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train_iter = iterators.SerialIterator(
        train, batch_size=params_local["batch_size"], shuffle=True)

    if use_gpu:
        updater = training.StandardUpdater(train_iter,
                                           optimizer,
                                           device=gpus[0])
    else:
        updater = training.StandardUpdater(train_iter, optimizer)

    trainer = training.Trainer(updater,
                               stop_trigger=(params_local["nb_epoch"],
                                             'epoch'),
                               out=params_local["path_results"])

    if test is not None:
        test_iter = iterators.SerialIterator(
            test,
            batch_size=params_local["batch_size"],
            repeat=False,
            shuffle=False)
        # Evaluate on the same device the model was moved to; with no GPU
        # configured ``gpus`` is empty and must not be indexed.
        if use_gpu:
            trainer.extend(
                extensions.Evaluator(test_iter, model, device=gpus[0]))
        else:
            trainer.extend(extensions.Evaluator(test_iter, model))

    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(filename='snapshot_last.npz'),
                   trigger=(1, 'epoch'))
    trainer.extend(extensions.observe_lr())
    trainer.extend(extensions.observe_value("alpha",
                                            lambda _: optimizer.alpha))
    # multiply Adam's alpha by 0.99 once per epoch
    trainer.extend(extensions.ExponentialShift("alpha", 0.99),
                   trigger=(1, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'main/accuracy',
            'validation/main/accuracy', 'alpha', 'elapsed_time'
        ]))

    trainer.run()
Beispiel #32
0
def main():
    if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.do_print_test:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if FLAGS.do_train:
        if not FLAGS.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if FLAGS.do_predict:
        if not FLAGS.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = read_squad_examples(
            input_file=FLAGS.train_file, is_training=True)
        train_features = convert_examples_to_features(
            train_examples, tokenizer, FLAGS.max_seq_length,
            FLAGS.doc_stride, FLAGS.max_query_length, is_training=True)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    bert = modeling.BertModel(config=bert_config)
    model = modeling.BertSQuAD(bert)
    if FLAGS.do_train:
        # If training, load BERT parameters only.
        ignore_names = ['output/W', 'output/b']
    else:
        # If only do_predict, load all parameters.
        ignore_names = None
    chainer.serializers.load_npz(
        FLAGS.init_checkpoint, model,
        ignore_names=ignore_names)

    if FLAGS.gpu >= 0:
        chainer.backends.cuda.get_device_from_id(FLAGS.gpu).use()
        model.to_gpu()

    if FLAGS.do_train:
        # Adam with weight decay only for 2D matrices
        optimizer = optimization.WeightDecayForMatrixAdam(
            alpha=1.,  # ignore alpha. instead, use eta as actual lr
            eps=1e-6, weight_decay_rate=0.01)
        optimizer.setup(model)
        optimizer.add_hook(chainer.optimizer.GradientClipping(1.))

        train_iter = chainer.iterators.SerialIterator(
            train_features, FLAGS.train_batch_size)
        converter = Converter(is_training=True)
        updater = training.updaters.StandardUpdater(
            train_iter, optimizer,
            converter=converter,
            device=FLAGS.gpu,
            loss_func=model.compute_loss)
        trainer = training.Trainer(
            updater, (num_train_steps, 'iteration'), out=FLAGS.output_dir)

        # learning rate (eta) scheduling in Adam
        lr_decay_init = FLAGS.learning_rate * \
            (num_train_steps - num_warmup_steps) / num_train_steps
        trainer.extend(extensions.LinearShift(  # decay
            'eta', (lr_decay_init, 0.), (num_warmup_steps, num_train_steps)))
        trainer.extend(extensions.WarmupShift(  # warmup
            'eta', 0., num_warmup_steps, FLAGS.learning_rate))
        trainer.extend(extensions.observe_value(
            'eta', lambda trainer: trainer.updater.get_optimizer('main').eta),
            trigger=(100, 'iteration'))  # logging

        trainer.extend(extensions.snapshot_object(
            model, 'model_snapshot_iter_{.updater.iteration}.npz'),
            trigger=(num_train_steps // 2, 'iteration'))  # TODO
        trainer.extend(extensions.LogReport(trigger=(100, 'iteration')))
        trainer.extend(extensions.PrintReport(
            ['iteration', 'main/loss',
             'main/accuracy', 'elapsed_time', 'eta']))
        trainer.extend(extensions.ProgressBar(update_interval=10))

        trainer.run()

    if FLAGS.do_predict:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file, is_training=False)
        eval_features = convert_examples_to_features(
            eval_examples, tokenizer, FLAGS.max_seq_length,
            FLAGS.doc_stride, FLAGS.max_query_length, is_training=False)
        test_iter = chainer.iterators.SerialIterator(
            eval_features, FLAGS.predict_batch_size,
            repeat=False, shuffle=False)
        converter = Converter(is_training=False)

        print('Evaluating ...')
        evaluate(eval_examples, test_iter, model,
                 converter=converter, device=FLAGS.gpu,
                 predict_func=model.predict)
        print('Finished.')