Ejemplo n.º 1
0
def train_model(model, train_dataset, valid_dataset, fields, text_model,
                text_train_dataset, text_valid_dataset, text_fields,
                discrim_models, discrim_optims, gen_optims, optim, text_optim,
                model_opt):
    """Adversarially train an audio model and a text model together.

    Every epoch:
      * the discriminators (``onmt.DiscrimTrainer``) are trained, with their
        validation losses printed before training (only when ``epoch > 1``)
        and after training;
      * ``onmt.AudioTextTrainerAdv`` trains the generators;
      * the audio and text validation sets are evaluated;
      * results are logged remotely when ``opt.exp_host`` is set;
      * the learning rate is stepped on the audio validation perplexity;
      * checkpoints are dropped once ``epoch >= opt.start_checkpoint_at``.

    ``text_fields`` is accepted but not used here. Names resolved outside
    this function: ``opt``, ``experiment``, ``onmt``, ``report_func``,
    ``discrim_report_func``, ``make_train_data_iter``,
    ``make_valid_data_iter``, ``make_loss_compute``.
    """
    # Iterators consumed by the generator trainer.
    audio_train_iter = make_train_data_iter(train_dataset, opt)
    audio_valid_iter = make_valid_data_iter(valid_dataset, opt)
    txt_train_iter = make_train_data_iter(text_train_dataset, opt)
    txt_valid_iter = make_valid_data_iter(text_valid_dataset, opt)

    # All four losses use the audio model's target vocabulary. The text
    # training loss is built on ``train_dataset``. Both text losses pass an
    # extra ``True`` flag to make_loss_compute.
    vocab = fields["tgt"].vocab
    audio_loss = make_loss_compute(model, vocab, train_dataset, opt)
    txt_loss = make_loss_compute(model, vocab, train_dataset, opt, True)
    audio_valid_loss = make_loss_compute(model, vocab, valid_dataset, opt)
    txt_valid_loss = make_loss_compute(model, vocab, text_valid_dataset,
                                       opt, True)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    data_type = train_dataset.data_type

    gen_labels = [model_opt.gen_label] * 2
    trainer = onmt.AudioTextTrainerAdv(
        model, audio_train_iter, audio_valid_iter,
        text_model, txt_train_iter, txt_valid_iter,
        audio_loss, txt_loss, audio_valid_loss, txt_valid_loss,
        optim, text_optim, discrim_models, gen_labels, gen_optims,
        model_opt.gen_lambda, trunc_size, shard_size, data_type,
        model_opt.mult)

    # The discriminators get their own fresh training iterators. These are
    # built with an extra argument of 32, presumably a batch size; confirm in
    # make_train_data_iter. They reuse the generator's validation iterators.
    discrim_train_iters = [make_train_data_iter(train_dataset, opt, 32),
                           make_train_data_iter(text_train_dataset, opt, 32)]
    discrim_trainer = onmt.DiscrimTrainer(
        discrim_models, discrim_train_iters,
        [audio_valid_iter, txt_valid_iter],
        discrim_optims, [0.1, 0.9], shard_size)

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # Discriminator phase.
        if epoch > 1:
            d_src, d_tgt, d_st = discrim_trainer.validate()
            print('(before) Discrim validation src loss: %g' % d_src.loss)
            print('(before) Discrim validation src/tgt loss: %g' % d_st.loss)
            print('(before) Discrim validation tgt loss: %g' % d_tgt.loss)

        d_src, d_tgt = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim src loss: %g' % d_src.loss)
        print('Discrim tgt loss: %g' % d_tgt.loss)

        d_src, d_tgt, d_st = discrim_trainer.validate()
        print('(after) Discrim validation src loss: %g' % d_src.loss)
        print('(after) Discrim validation src/tgt loss: %g' % d_st.loss)
        print('(after) Discrim validation tgt loss: %g' % d_tgt.loss)

        # Generator phase: one epoch over the training data.
        audio_stats, txt_stats = trainer.train(epoch, report_func)
        print('Train perplexity: %g' % audio_stats.ppl())
        print('Train accuracy: %g' % audio_stats.accuracy())
        print('Text perplexity: %g' % txt_stats.ppl())
        print('Text accuracy: %g' % txt_stats.accuracy())

        # Validation on both the audio and the text sets.
        audio_valid_stats = trainer.validate()
        print('Validation perplexity: %g' % audio_valid_stats.ppl())
        print('Validation accuracy: %g' % audio_valid_stats.accuracy())
        txt_valid_stats = trainer.validate_text()
        print('Text Validation perplexity: %g' % txt_valid_stats.ppl())
        print('Text Validation accuracy: %g' % txt_valid_stats.accuracy())

        # Remote logging (audio stats only).
        if opt.exp_host:
            audio_stats.log("train", experiment, optim.lr)
            audio_valid_stats.log("valid", experiment, optim.lr)

        # Learning-rate schedule driven by audio validation perplexity.
        trainer.epoch_step(audio_valid_stats.ppl(), epoch)

        # Checkpoints for both the generator and the discriminators.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields,
                                    audio_valid_stats)
            discrim_trainer.drop_checkpoint(model_opt, epoch, fields,
                                            audio_valid_stats)
Ejemplo n.º 2
0
def train_model(auto_models, valid_model, train_data, valid_data, fields_list,
                valid_fields, optims, discrim_models, discrim_optims, labels):
    """Alternate discriminator training and adversarial auto-encoder training.

    Builds three kinds of trainer:
      * ``trainers[0]``: an ``onmt.Trainer`` for ``valid_model``, used only
        to validate and checkpoint;
      * ``trainers[1:]``: one ``onmt.AdvTrainer`` per entry of the zipped
        ``auto_models``/``discrim_models``/``optims``/``labels``/
        ``train_data``/``fields_list``;
      * ``discrim_trainers``: one ``onmt.DiscrimTrainer`` per entry of the
        zipped ``discrim_models``/``discrim_optims``/``train_data``.

    Because these sequences are zipped, the shortest one decides how many
    trainers are built. Each epoch trains every discriminator, then every
    adversarial trainer, then validates ``valid_model`` and (from
    ``opt.start_checkpoint_at`` on) drops a checkpoint with ``valid_fields``.

    Names resolved outside this function: ``opt``, ``experiment``, ``onmt``,
    ``report_func``, ``discrim_report_func``, ``make_train_data_iter``,
    ``make_valid_data_iter``, ``make_loss_compute``.
    """
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)
    valid_optim = optims[0]
    trainers = [onmt.Trainer(valid_model, valid_iter, valid_iter, valid_loss,
                             valid_loss, valid_optim, trunc_size, shard_size)]

    # adv_optims[i] is the optimizer of trainers[i + 1]. It is tracked so
    # remote logging reports that trainer's own learning rate, rather than
    # whatever optimizer a leaked loop variable happened to hold.
    adv_optims = []
    for model, discrim_model, optim, label, train, fields in zip(
            auto_models, discrim_models, optims, labels, train_data,
            fields_list):
        train_iter = make_train_data_iter(train, opt)
        train_loss = make_loss_compute(model, fields["tgt"].vocab, train, opt)
        trainers.append(
            onmt.AdvTrainer(model, discrim_model, train_iter, valid_iter,
                            train_loss, valid_loss, optim, label, trunc_size,
                            shard_size))
        adv_optims.append(optim)

    # discrim_trainers[i] is built with discrim_optims[i].
    discrim_trainers = []
    for model, discrim_optim, data in zip(discrim_models, discrim_optims,
                                          train_data):
        train_iter = make_train_data_iter(data, opt)
        discrim_trainers.append(
            onmt.DiscrimTrainer(model, train_iter, discrim_optim, shard_size))

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1a. Train every discriminator for one epoch.
        for label, trainer, discrim_optim in zip(labels, discrim_trainers,
                                                 discrim_optims):
            train_stats = trainer.train(epoch, label, discrim_report_func)
            print('Train loss: %g' % train_stats.loss)

            if opt.exp_host:
                train_stats.log("train", experiment, discrim_optim.lr)

        # 1b. Train every adversarial auto-encoder for one epoch.
        for trainer, optim in zip(trainers[1:], adv_optims):
            train_stats = trainer.train(epoch, report_func)
            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())

            if opt.exp_host:
                train_stats.log("train", experiment, optim.lr)

        # 2. Validate on the validation set.
        valid_stats = trainers[0].validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            valid_stats.log("valid", experiment, valid_optim.lr)

        # 4. No learning-rate update: epoch_step() is not called on any
        #    trainer in this variant.

        # 5. Drop a checkpoint if needed. The checkpointed trainer wraps
        #    valid_model, so it is saved with valid_fields.
        if epoch >= opt.start_checkpoint_at:
            trainers[0].drop_checkpoint(opt, epoch, valid_fields, valid_stats)
def train_model(model, train_dataset, valid_dataset, fields,
                text_model, text_train_dataset, text_valid_dataset, text_fields,
                speech_model, speech_train_dataset,
                discrim_models, discrim_optims,
                optim, adv_optim, speech_optim, model_opt, big_text):
    """Train the audio/text/speech model, optionally against discriminators.

    Builds an ``onmt.AudioTextSpeechTrainerAdv``. It uses
    ``onmt.AudioTextSpeechTrainerAdvFMatch`` instead when
    ``opt.feature_match`` is set and that class exists. Unless
    ``opt.no_adv`` is set, it also builds an ``onmt.DiscrimTrainer``.

    Each epoch:
      * the discriminators are trained, with validation losses printed
        before (``epoch > 1``) and after;
      * the generator trainer runs one epoch;
      * audio and text validation sets are evaluated;
      * remote logging, the learning-rate step and checkpoints follow.

    Args:
        big_text: when truthy, text training data is read from sharded files
            ``opt.text_data + '.train.<k>.pt'``. One shard is loaded per
            epoch, wrapping around to shard 1 when a shard file is missing.

    Raises:
        IOError: if ``big_text`` is set and no shard file matches.

    Side effects: may set ``model_opt.unsup``, ``model_opt.start_mask`` and
    ``model_opt.end_mask`` when they are absent. Names resolved outside this
    function: ``opt``, ``experiment``, ``onmt``, ``glob``, ``torch``,
    ``report_func``, ``discrim_report_func``, ``make_train_data_iter``,
    ``make_valid_data_iter``, ``make_loss_compute``.
    """
    # Older model options may lack ``delete_rate``. Resolve it once so that
    # every text iterator below, including the per-epoch big_text ones,
    # gets the same value instead of some sites crashing on a missing
    # attribute.
    delete_rate = getattr(model_opt, 'delete_rate', 0.2)

    train_iter = make_train_data_iter(train_dataset, opt)
    valid_iter = make_valid_data_iter(valid_dataset, opt)
    text_valid_iter = make_valid_data_iter(text_valid_dataset, opt)

    text_train_iter = make_train_data_iter(text_train_dataset, opt)
    text_train_iter.dR = delete_rate

    speech_train_iter = make_train_data_iter(speech_train_dataset, opt)

    tgt_vocab = fields["tgt"].vocab
    train_loss = make_loss_compute(model, tgt_vocab, train_dataset, opt)
    text_loss = make_loss_compute(model, tgt_vocab, train_dataset, opt, True)
    valid_loss = make_loss_compute(model, tgt_vocab, valid_dataset, opt)
    # NOTE(review): unlike text_loss, no trailing True flag is passed here.
    # Confirm this is intended.
    text_valid_loss = make_loss_compute(model, tgt_vocab, text_valid_dataset,
                                        opt)

    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches
    data_type = train_dataset.data_type

    speech_lambda = opt.auto_lambda
    if model_opt.weighted:
        speech_lambda = speech_lambda / float(opt.mult)

    # Single-argument %-formatting prints the same text under Python 2 and 3.
    print('label: %s' % (model_opt.gen_label,))
    if hasattr(model_opt, 'unsup'):
        print('%s' % (model_opt.unsup,))
    else:
        model_opt.unsup = False

    if opt.no_adv:
        discrim_models = [None, None]

    # Prefer the feature-matching trainer when requested. Fall back to the
    # plain adversarial trainer when ``opt`` has no ``feature_match`` option
    # or this ``onmt`` build lacks the feature-matching class. Any other
    # construction error now propagates instead of being silently swallowed.
    trainer_cls = onmt.AudioTextSpeechTrainerAdv
    if getattr(opt, 'feature_match', False):
        trainer_cls = getattr(onmt, 'AudioTextSpeechTrainerAdvFMatch',
                              trainer_cls)
    trainer = trainer_cls(model, train_iter, valid_iter,
                          text_model, text_train_iter, text_valid_iter,
                          speech_model, speech_train_iter,
                          train_loss, text_loss, valid_loss, text_valid_loss,
                          optim, adv_optim, speech_optim,
                          discrim_models,
                          [model_opt.gen_label, model_opt.gen_label],
                          model_opt.gen_lambda, speech_lambda,
                          trunc_size, shard_size, data_type,
                          model_opt.mult, model_opt.t_mult, model_opt.unsup,
                          big_text=big_text)

    if opt.ff_speech_decoder:
        trainer.ff = True

    if not opt.no_adv:
        # The discriminator gets its own iterators, built with an extra
        # argument of 32 (presumably a batch size; confirm).
        speech_train_iter = make_train_data_iter(speech_train_dataset, opt, 32)
        text_train_iter = make_train_data_iter(text_train_dataset, opt, 32)
        text_train_iter.dR = delete_rate
        discrim_trainer = onmt.DiscrimTrainer(
            discrim_models, [speech_train_iter, text_train_iter],
            [valid_iter, text_valid_iter],
            discrim_optims, [0.1, 0.9], shard_size, big_text)

    # ``override`` is forwarded to trainer.train(); its meaning is defined
    # by the trainer.
    override = 50 if model_opt.unsup else 2000
    print("OVERRIDE: " + str(override))

    try:
        model_opt.start_mask = max(0, model_opt.start_mask)
    except AttributeError:
        model_opt.start_mask = 0
        model_opt.end_mask = 0

    advOnly = False

    if big_text:
        n_text = len(glob.glob(opt.text_data + '.train.[0-9]*.pt'))
        if n_text == 0:
            # Previously surfaced as a ZeroDivisionError on the modulo below.
            raise IOError('no text shards match %s.train.[0-9]*.pt'
                          % opt.text_data)
        text_idx = opt.start_epoch % n_text
        print('idx: %s %s %s' % (text_idx, opt.start_epoch, n_text))

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        if big_text:
            try:
                text_train_dataset = torch.load(
                    opt.text_data + '.train.' + str(text_idx) + '.pt')
            except (IOError, OSError):
                # The shard file is missing (past the last shard, or index 0,
                # which the numbering presumably skips): wrap around to
                # shard 1.
                text_train_dataset = torch.load(opt.text_data + '.train.1.pt')
                text_idx = 1
            text_train_dataset.fields = text_fields
            print('LOADED BIG TEXT: %s' % (text_idx,))
            text_idx += 1

        mask_kwargs = {'startMask': model_opt.start_mask,
                       'endMask': model_opt.end_mask}

        if not opt.no_adv:
            if epoch > 1:
                src_valid_stats, tgt_valid_stats, st_valid_stats = \
                    discrim_trainer.validate()
                print('(before) Discrim validation src loss: %g' % src_valid_stats.loss)
                print('(before) Discrim validation src/tgt loss: %g' % st_valid_stats.loss)
                print('(before) Discrim validation tgt loss: %g' % tgt_valid_stats.loss)
            discrim_kwargs = dict(mask_kwargs)
            if big_text:
                text_train_iter = make_train_data_iter(text_train_dataset, opt, 32)
                text_train_iter.dR = delete_rate
                discrim_kwargs['text'] = text_train_iter
            src_train_stats, tgt_train_stats = discrim_trainer.train(
                epoch, discrim_report_func, **discrim_kwargs)
            print('Discrim src loss: %g' % src_train_stats.loss)
            print('Discrim tgt loss: %g' % tgt_train_stats.loss)
            src_valid_stats, tgt_valid_stats, st_valid_stats = \
                discrim_trainer.validate()
            print('(after) Discrim validation src loss: %g' % src_valid_stats.loss)
            print('(after) Discrim validation src/tgt loss: %g' % st_valid_stats.loss)
            print('(after) Discrim validation tgt loss: %g' % tgt_valid_stats.loss)

        # 1. Train for one epoch on the training set.
        gen_kwargs = dict(mask_kwargs, advOnly=advOnly)
        if big_text:
            text_train_iter = make_train_data_iter(text_train_dataset, opt)
            text_train_iter.dR = delete_rate
            gen_kwargs['text'] = text_train_iter
        (train_stats, text_train_stats, speech_train_stats,
         discrim_train_stats) = trainer.train(epoch, report_func, override,
                                              **gen_kwargs)
        # NOTE(review): this reads opt.unsup, while the trainer and override
        # use model_opt.unsup. Confirm they are meant to differ.
        if not opt.unsup and not advOnly:
            print('Train perplexity: %g' % train_stats.ppl())
            print('Train accuracy: %g' % train_stats.accuracy())
        if not advOnly:
            print('Text perplexity: %g' % text_train_stats.ppl())
            print('Text accuracy: %g' % text_train_stats.accuracy())
        # These stats may be unavailable in some configurations (presumably
        # None, or lacking a numeric ``loss``). Skip them instead of failing,
        # but no longer hide unrelated errors.
        for fmt, stats in (('Speech MSE: %g', speech_train_stats),
                           ('Discrim Loss: %g', discrim_train_stats)):
            try:
                print(fmt % stats.loss)
            except (AttributeError, TypeError):
                pass

        # 2. Validate on the validation set.
        valid_stats = trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())
        text_valid_stats = trainer.validate_text()
        print('Text validation perplexity: %g' % text_valid_stats.ppl())
        print('Text validation accuracy: %g' % text_valid_stats.accuracy())

        # 3. Log to remote server.
        if opt.exp_host:
            train_stats.log("train", experiment, optim.lr)
            valid_stats.log("valid", experiment, optim.lr)

        # 4. Update the learning rate
        trainer.epoch_step(valid_stats.ppl(), epoch)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            trainer.drop_checkpoint(model_opt, epoch, fields, valid_stats)
            if not opt.no_adv:
                discrim_trainer.drop_checkpoint(model_opt, epoch, fields,
                                                valid_stats)
Ejemplo n.º 4
0
def train_model(auto_models, valid_model, train_data, valid_data, fields_list,
                valid_fields, optims, discrim_models, discrim_optims, labels,
                advers_optims):
    """Pre-train two auto-encoders, then alternate with discriminator training.

    Index 0 of ``auto_models``/``train_data``/``fields_list``/``optims`` is
    the source side and index 1 is the target side.

    Phase 1 (epochs ``opt.start_epoch`` .. ``opt.epochs``): train the source
    and the target auto-encoder once per epoch.

    Phase 2 (the same epoch range again): per epoch, train the source
    auto-encoder, then the discriminators (``discrim_optims`` with
    ``labels``), then an adversarial pass over the same discriminators
    (``advers_optims`` with labels ``[0.9, 0.9]``). After that, validate
    ``valid_model`` and, from ``opt.start_checkpoint_at`` on, checkpoint it
    with ``valid_fields``.

    Names resolved outside this function: ``opt``, ``experiment``, ``onmt``,
    ``discrim_report_func``, ``make_train_data_iter``,
    ``make_valid_data_iter``, ``make_loss_compute``.
    """
    trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches

    valid_iter = make_valid_data_iter(valid_data, opt)
    valid_loss = make_loss_compute(valid_model, valid_fields["tgt"].vocab,
                                   valid_data, opt)
    # Used only to validate and checkpoint valid_model; built with optims[0].
    valid_trainer = onmt.Trainer(valid_model, valid_iter, valid_iter,
                                 valid_loss, valid_loss, optims[0], trunc_size,
                                 shard_size)

    # Source auto-encoder trainer (optims[0]).
    src_train_iter = make_train_data_iter(train_data[0], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    src_trainer = onmt.Trainer(auto_models[0], src_train_iter, valid_iter,
                               src_train_loss, valid_loss, optims[0],
                               trunc_size, shard_size)

    # Target auto-encoder trainer (optims[1]).
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    tgt_trainer = onmt.Trainer(auto_models[1], tgt_train_iter, valid_iter,
                               tgt_train_loss, valid_loss, optims[1],
                               trunc_size, shard_size)

    # Each trainer below gets freshly built iterators (and, for the unsup
    # trainer, fresh losses) rather than sharing the ones above.
    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    src_train_loss = make_loss_compute(auto_models[0],
                                       fields_list[0]["tgt"].vocab,
                                       train_data[0], opt)
    tgt_train_loss = make_loss_compute(auto_models[1],
                                       fields_list[1]["tgt"].vocab,
                                       train_data[1], opt)
    # NOTE(review): unsup_trainer is constructed but never trained in this
    # function; its train() call was disabled.
    unsup_trainer = onmt.UnsupTrainer(
        auto_models, [None, None], discrim_models,
        [src_train_iter, tgt_train_iter], valid_iter,
        [src_train_loss, tgt_train_loss], [None, None], valid_loss,
        [optims[0], None], [None, None], [0.9, 0.9], trunc_size, shard_size)

    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    discrim_trainer = onmt.DiscrimTrainer(discrim_models,
                                          [src_train_iter, tgt_train_iter],
                                          discrim_optims, labels, shard_size)

    src_train_iter = make_train_data_iter(train_data[0], opt)
    tgt_train_iter = make_train_data_iter(train_data[1], opt)
    advers_trainer = onmt.DiscrimTrainer(discrim_models,
                                         [src_train_iter, tgt_train_iter],
                                         advers_optims, [0.9, 0.9], shard_size)

    # Phase 1: auto-encoder pre-training.
    # NOTE(review): the auto-encoder trainers receive discrim_report_func
    # rather than report_func. Confirm this is intended.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())

        # 1. Train for one epoch on the training set.
        train_stats = tgt_trainer.train(epoch, discrim_report_func)
        print('TGT Train perplexity: %g' % train_stats.ppl())
        print('TGT Train accuracy: %g' % train_stats.accuracy())

        if opt.exp_host:
            # train_stats holds the target trainer's stats, which was built
            # with optims[1]. The previous code referenced an undefined
            # ``optim`` here.
            train_stats.log("train", experiment, optims[1].lr)

    # Phase 2: auto-encoder + discriminator + adversarial training.
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        # 1. Train for one epoch on the training set.
        train_stats = src_trainer.train(epoch, discrim_report_func)
        print('SRC Train perplexity: %g' % train_stats.ppl())
        print('SRC Train accuracy: %g' % train_stats.accuracy())

        train_stats = discrim_trainer.train(epoch, discrim_report_func)
        print('Discrim Train loss: %g' % train_stats.loss)

        train_stats = advers_trainer.train(epoch, discrim_report_func)
        print('Advers Train loss: %g' % train_stats.loss)

        # 2. Validate on the validation set.
        valid_stats = valid_trainer.validate()
        print('Validation perplexity: %g' % valid_stats.ppl())
        print('Validation accuracy: %g' % valid_stats.accuracy())

        # 3. Log to remote server (valid_trainer was built with optims[0]).
        if opt.exp_host:
            valid_stats.log("valid", experiment, optims[0].lr)

        # 5. Drop a checkpoint if needed.
        if epoch >= opt.start_checkpoint_at:
            valid_trainer.drop_checkpoint(opt, epoch, valid_fields,
                                          valid_stats)