Esempio n. 1
0
# Evaluation script for the sentence-length probing task: rebuild the trained
# classifier, restore its weights, and report test-set accuracy plus the
# confusion matrix.
# NOTE(review): relies on module-level names `gd`, `m`, `np`, `Tests`,
# `confusion_matrix`, `input_dim`, `output_dim` and `batch_size` — confirm
# they are imported/defined earlier in this file.
model_path = "models/enc_dec_500/sen.len.enc.dec.500.model.net"
data_dir = "data/idx/sen_len/"
rep_dir = "data/enc_dec_500/"
task = Tests.SENTENCE_LENGTH


# load pre-computed word representations and the three data splits
x_train, y_train, x_test, y_test, x_val, y_val = gd.build_data(
    rep_dir + "word_repr.npy",
    data_dir + "train.txt",
    rep_dir + "train.rep.npy",
    data_dir + "test.txt",
    rep_dir + "test.rep.npy",
    data_dir + "val.txt",
    rep_dir + "val.rep.npy",
    task=task,
)

# rebuild the architecture and restore the trained weights
model = m.build_model(input_dim=input_dim, output_dim=output_dim)
model.load_weights(model_path)

# predicted class index for every test example
y_hat = model.predict_classes(x_test, batch_size=batch_size)

# gold labels are one-hot encoded; take the argmax once and reuse it
gold = np.argmax(y_test, axis=1)
accuracy = float(np.sum(y_hat == gold)) / len(y_hat)

print("")
print("Accuracy on the test set: %s" % accuracy)
print(confusion_matrix(gold, y_hat))
Esempio n. 2
0
def main(params):
    """Train (or evaluate) an image classifier described by ``params``.

    Initializes distributed mode, data loaders, the model and the
    trainer/evaluator, then either runs evaluation only (``params.eval_only``)
    or loops over classification training epochs.
    """

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment / load data
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    if params.is_slurm_job:
        init_signal_handler()

    # dataset-dependent number of classes and image geometry
    if params.dataset == "imagenet":
        params.num_classes = 1000
        params.img_size = 256
        params.crop_size = 224
    else:
        if params.dataset == "cifar10":
            params.num_classes = 10
        elif params.dataset == "cifar100":
            params.num_classes = 100
        else:
            assert False, "Dataset unbeknownst to me"

        # common geometry for both CIFAR variants
        params.img_size = 40
        params.crop_size = 32

    # data loaders / samplers
    train_data_loader, train_sampler = get_data_loader(
        img_size=params.img_size,
        crop_size=params.crop_size,
        shuffle=True,
        batch_size=params.batch_size,
        nb_workers=params.nb_workers,
        distributed_sampler=params.multi_gpu,
        dataset=params.dataset,
        transform=params.transform,
        # in debug mode, train on the (smaller) validation split
        split='valid' if params.debug_train else params.split_train,
    )

    valid_data_loader, _ = get_data_loader(
        img_size=params.img_size,
        crop_size=params.crop_size,
        shuffle=False,
        batch_size=params.batch_size,
        nb_workers=params.nb_workers,
        distributed_sampler=False,
        dataset=params.dataset,
        transform='center',
        split='valid',
    )

    # build model / cuda
    logger.info("Building %s model ..." % params.architecture)
    model = build_model(params)
    model.cuda()

    # distributed  # TODO: check this https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main.py#L142
    if params.multi_gpu:
        logger.info("Using nn.parallel.DistributedDataParallel ...")
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[params.local_rank],
            output_device=params.local_rank,
            broadcast_buffers=True)

    # build trainer / reload potential checkpoints / build evaluator
    trainer = Trainer(model=model, params=params)
    trainer.reload_checkpoint()
    evaluator = Evaluator(trainer, params)

    # evaluation only: run all evals, log the scores, and stop
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer,
                                         evals=['classif', 'recognition'],
                                         data_loader=valid_data_loader)

        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # training (resumes from trainer.epoch if a checkpoint was reloaded)
    for epoch in range(trainer.epoch, params.epochs):

        # update epoch / sampler / learning rate
        trainer.epoch = epoch
        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)
        if params.multi_gpu:
            # keep shard shuffling consistent across distributed workers
            train_sampler.set_epoch(epoch)

        # update learning rate
        trainer.update_learning_rate()

        # train
        for i, (images, targets) in enumerate(train_data_loader):
            trainer.classif_step(images, targets)
            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate classification accuracy
        scores = evaluator.run_all_evals(trainer,
                                         evals=['classif'],
                                         data_loader=valid_data_loader)

        # merge training statistics into the score dictionary
        for name, val in trainer.get_scores().items():
            scores[name] = val

        # print / JSON log
        for k, v in scores.items():
            logger.info('%s -> %.6f' % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Esempio n. 3
0
def main(params):
    """Train (or evaluate) an XLM-style language / translation model.

    Either reloads a pretrained encoder/decoder checkpoint from
    ``params.model_path`` or builds a fresh model, then runs the requested
    training objectives (CLM / MLM / PC / AE / MT / BT) epoch by epoch.
    """

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # load checkpoint
    if params.model_path != "":
        reloaded = torch.load(params.model_path)
        model_params = AttrDict(reloaded['params'])
        dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'],
                          reloaded['dico_counts'])
        # BUGFIX: the encoder and decoder used to be constructed twice in a
        # row (two identical pairs of TransformerModel allocations); build
        # each exactly once before loading the checkpoint weights.
        encoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=True,
                                   with_output=True).cuda().eval()
        decoder = TransformerModel(model_params,
                                   dico,
                                   is_encoder=False,
                                   with_output=True).cuda().eval()
        encoder.load_state_dict(reloaded['encoder'])
        decoder.load_state_dict(reloaded['decoder'])
        logger.info("Supported languages: %s" %
                    ", ".join(model_params.lang2id.keys()))
    else:
        # build model from scratch
        if params.encoder_only:
            model = build_model(params, data['dico'])
        else:
            encoder, decoder = build_model(params, data['dico'])

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)

    # evaluation only: run all evals, log the scores, and stop
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)

        trainer.n_sentences = 0

        # one "epoch" = epoch_size sentences, mixing all enabled objectives
        while trainer.n_sentences < trainer.epoch_size:

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                trainer.clm_step(lang1, lang2, params.lambda_clm)

            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                trainer.mlm_step(lang1, lang2, params.lambda_mlm)

            # parallel classification steps
            for lang1, lang2 in shuf_order(params.pc_steps, params):
                trainer.pc_step(lang1, lang2, params.lambda_pc)

            # denoising auto-encoder steps (MT step with lang1 == lang2)
            for lang in shuf_order(params.ae_steps):
                trainer.mt_step(lang, lang, params.lambda_ae)

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                trainer.mt_step(lang1, lang2, params.lambda_mt)

            # back-translation steps
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                trainer.bt_step(lang1, lang2, lang3, params.lambda_bt)

            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(trainer)

        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Esempio n. 4
0
def main(params):
    """Train (or evaluate) a MASS-style encoder/decoder language model.

    Builds the model (optionally fp16 / apex-distributed), then loops over
    epochs running the enabled objectives (CLM / MLM / PC / AE / MASS / MT /
    BT / BMT).
    """

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # build model
    if params.encoder_only:
        model = build_model(params, data['dico'])
    else:
        encoder, decoder = build_model(params, data['dico'])

    # float16: convert the network(s) to half precision
    if params.fp16:
        assert torch.backends.cudnn.enabled
        if params.encoder_only:
            model = network_to_half(model)
        else:
            encoder = network_to_half(encoder)
            decoder = network_to_half(decoder)

    # distributed (apex DDP; gradients all-reduced at the end of backward)
    if params.multi_gpu:
        logger.info("Using nn.parallel.DistributedDataParallel ...")
        if params.encoder_only:
            model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)
        else:
            encoder = apex.parallel.DistributedDataParallel(encoder, delay_allreduce=True)
            decoder = apex.parallel.DistributedDataParallel(decoder, delay_allreduce=True)

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)

    # evaluation only: log scores, dump them as JSON, and stop
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        with open(params.eval_output, 'w') as f:
            scores['temperature'] = params.softmax_temperature
            json.dump(dict(scores), f, indent=4)
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" % trainer.epoch)

        trainer.n_sentences = 0

        # one "epoch" = epoch_size sentences, mixing all enabled objectives
        while trainer.n_sentences < trainer.epoch_size:

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                trainer.clm_step(lang1, lang2, params.lambda_clm)

            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                trainer.mlm_step(lang1, lang2, params.lambda_mlm)

            # parallel classification steps
            for lang1, lang2 in shuf_order(params.pc_steps, params):
                trainer.pc_step(lang1, lang2, params.lambda_pc)

            # denoising auto-encoder steps (MT step with lang1 == lang2)
            for lang in shuf_order(params.ae_steps):
                trainer.mt_step(lang, lang, params.lambda_ae)

            # mass prediction steps
            for lang in shuf_order(params.mass_steps):
                trainer.mass_step(lang, params.lambda_mass)

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                trainer.mt_step(lang1, lang2, params.lambda_mt)

            # back-translation steps
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                trainer.bt_step(lang1, lang2, lang3, params.lambda_bt)

            # back-parallel steps
            for lang1, lang2 in shuf_order(params.bmt_steps, params):
                trainer.bmt_step(lang1, lang2, params.lambda_bmt)

            trainer.iter()

        logger.info("============ End of epoch %i ============" % trainer.epoch)

        # evaluate perplexity
        # scores = evaluator.run_all_evals(trainer)
        # NOTE(review): per-epoch evaluation is deliberately disabled above, so
        # an empty score dict is passed to the logging / checkpointing calls
        # below -- confirm save_best_model / end_epoch tolerate empty scores.
        scores = {}
        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Esempio n. 5
0
def main(params):
    """Train (or evaluate) an XLM-style model with sampled back-translation.

    Same pipeline as the standard XLM trainer, but ``bt_step`` additionally
    receives ``params.bt_sample_temperature``, and the best-model checkpoint
    is only saved when validation metrics are configured.
    """

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # build model
    if params.encoder_only:
        model = build_model(params, data['dico'])
    else:
        encoder, decoder = build_model(params, data['dico'])

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)

    # evaluation only: log the scores and stop
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)

        trainer.n_sentences = 0

        # one "epoch" = epoch_size sentences, mixing all enabled objectives
        while trainer.n_sentences < trainer.epoch_size:

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                trainer.clm_step(lang1, lang2, params.lambda_clm)

            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                trainer.mlm_step(lang1, lang2, params.lambda_mlm)

            # denoising auto-encoder steps (MT step with lang1 == lang2)
            for lang in shuf_order(params.ae_steps):
                trainer.mt_step(lang, lang, params.lambda_ae)

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                trainer.mt_step(lang1, lang2, params.lambda_mt)

            # back-translation steps (with sampling temperature)
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                trainer.bt_step(lang1, lang2, lang3, params.lambda_bt,
                                params.bt_sample_temperature)

            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(trainer)

        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch: only track a "best" model if a metric is configured
        if params.validation_metrics != '':
            trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Esempio n. 6
0
def load_train(ARGS,save_path,gamename,logfile):
    """Evaluate a population of ES models on ``ARGS.gamename`` and append the
    per-model mean rewards to ``logfile``.

    Args:
        ARGS: experiment configuration object (provides lam, ft,
            population_size, ncpu, env_type, refer_batch_size, ... plus
            helper methods such as check_env / set_logger / init).
        save_path: unused here; kept for call-site compatibility.
        gamename: unused here; the env name is taken from ARGS.gamename.
        logfile: path of the text log the mean rewards are appended to.

    NOTE(review): relies on module-level globals ``seed``, ``logger``,
    ``params`` and ``Small_value`` -- confirm they exist in this module.
    """
    env = gym.make(ARGS.gamename)
    ARGS.check_env(env)

    # set seeds (`seed` is expected to be a module-level global)
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    model_best = build_model(ARGS)
    best_test_score = Small_value
    best_kid_mean = Small_value

    # init ARGS's parameter
    ARGS.set_logger(logger)
    ARGS.init(model_best.get_size(), params)
    ARGS.output()

    # one model / sigma / optimizer pair per population member
    model_list = [build_model(ARGS) for i in range(ARGS.lam)]
    sigma_list = [build_sigma(model_best, ARGS) for i in range(ARGS.lam)]
    model_optimizer_list = [SGD(model_best.named_parameters(), ARGS.lr_mean) for i in range(ARGS.lam)]
    sigma_optimizer_list = [SGD(model_best.named_parameters(), ARGS.lr_sigma, sigma=True) for i in range(ARGS.lam)]
    for model in model_list:
        model.set_parameter_no_grad()

    pool = mp.Pool(processes=ARGS.ncpu)

    refer_batch_torch = None
    if ARGS.env_type == "atari":
        # get reference batch (for virtual batch normalisation)
        logger.info("start testing reference batch statistic")
        reference_batch = explore_for_vbn(env, 0.01, ARGS)
        refer_batch_torch = torch.zeros((ARGS.refer_batch_size, 4, 84, 84))
        for i in range(ARGS.refer_batch_size):
            refer_batch_torch[i] = reference_batch[i]


    timestep_count = 0
    test_rewards_list = []
    break_training = False
    all_zero_count = 0
    for g in range(1):
        # fitness evaluation times.
        # BUGFIX: was `[[0] * population_size] * len(model_list)`, which
        # aliases the SAME inner list for every model, and `list += ndarray`
        # extends the list instead of adding element-wise (making the later
        # `/ ARGS.ft` a TypeError). A 2-D array gives the intended semantics.
        rewards_list = np.zeros((len(model_list), ARGS.population_size))
        # BUGFIX: was `[v for i in ...]` with a single shared list `v`, so
        # every model appended into the same seed list.
        seed_list = [[] for _ in range(len(model_list))]
        for i in range(ARGS.ft):
            one_rewards_list, one_seed_list, frame_count = train_simulate(model_list, sigma_list, pool, env, ARGS, refer_batch_torch)
            timestep_count += frame_count
            # element-wise accumulation of rewards across evaluations
            rewards_list += np.array(one_rewards_list)
            # BUGFIX: the loop variable used to be named `seed`, which made
            # `seed` local to the whole function and broke `env.seed(seed)`
            # above with UnboundLocalError; use a distinct name.
            for j, one_seed in enumerate(one_seed_list):
                seed_list[j].append(one_seed)
        rewards_list = rewards_list / ARGS.ft

    # mean reward per model over its population
    rewards_mean_list = []
    for i in range(len(rewards_list)):
        rewards_mean_ = np.mean(np.array(rewards_list[i]))
        rewards_mean_list.append(rewards_mean_)

    with open(logfile,'a') as f:
        f.write(str(rewards_mean_list))

    # ---------------SAVE---------
    pool.close()
    pool.join()
Esempio n. 7
0
    mp.set_sharing_strategy("file_system")
    # log and save path setting
    torch.set_num_threads(1)
    # torch.manual_seed(int(time.time()))
    model_path = "log/2020-3-4-4/Qbert-phi-0.001-lam-5-mu-1526.pt"
    logfile = model_path[0:-3]+'.txt'

    gamename = "Qbert"
    
    ARGS.gamename = gamename + "NoFrameskip-v4"
    env = gym.make(ARGS.gamename)
    env.seed(int(time.time()))
    ARGS.action_n = env.action_space.n
    
    
    model = build_model(ARGS)
    model.load_state_dict(torch.load(model_path))
    pool = mp.Pool(processes=5)

    refer_batch_torch = None
    if ARGS.env_type == "atari":
        # get reference batch
        reference_batch = explore_for_vbn(env, 0.01, ARGS)
        refer_batch_torch = torch.zeros((ARGS.refer_batch_size, 4, 84, 84))
        for i in range(ARGS.refer_batch_size):
            refer_batch_torch[i] = reference_batch[i]

    test_rewards,test_timestep,test_noop_list_,_= test(model,pool,env,ARGS,refer_batch_torch,test_times=200)
    test_rewards_mean = np.mean(np.array(test_rewards))

    with open(logfile,'a') as f:
Esempio n. 8
0
def main(params):
    """Train (or evaluate) a data-to-text / summarization model.

    Builds an encoder-only or encoder-decoder model over source (and target)
    dictionaries, then loops over epochs running the enabled steps: content
    selection, summarization, and language modeling.
    """

    # initialize the experiment
    logger = initialize_exp(params)

    # load dataOld
    data = load_data(params)
    # check_vocab(dataOld)

    # build model
    if params.encoder_only:
        model = build_model(params, data['source_dico'])
    else:
        encoder, decoder = build_model(params, data['source_dico'],
                                       data['target_dico'])

    # build trainer, reload potential checkpoints / build evaluator
    if params.encoder_only:
        trainer = SingleTrainer(model, data, params)
        evaluator = SingleEvaluator(trainer, data, params)
    else:
        trainer = EncDecTrainer(encoder, decoder, data, params)
        evaluator = EncDecEvaluator(trainer, data, params)

    # evaluation only: log the scores and stop
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)

        # this trainer counts iterations (n_iter), not sentences
        trainer.n_iter = 0

        while trainer.n_iter < trainer.epoch_size:
            if params.cs_step:
                trainer.content_selection_step(params.lambda_cs)
            if params.sm_step:
                trainer.summarization_step(params.lambda_sm)
            if params.lm_step:
                trainer.clm_step(params.lambda_lm)
            trainer.iter()
        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(trainer)
        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        # NOTE(review): unlike the sibling trainers, end_epoch() is called
        # without scores here -- confirm this Trainer's signature.
        trainer.end_epoch()
Esempio n. 9
0
def main(params):
    """Train (or evaluate) a masked-language-model-only transformer.

    Minimal variant of the XLM training loop: a single model, and only MLM
    steps inside each epoch.
    """

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # build model
    model = build_model(params, data['dico'])

    # build trainer, reload potential checkpoints / build evaluator
    trainer = Trainer(model, data, params)
    evaluator = Evaluator(trainer, data, params)

    # evaluation only: log the scores and stop
    if params.eval_only:
        scores = evaluator.run_all_evals(trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # language model training
    for _ in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" %
                    trainer.epoch)

        trainer.n_sentences = 0

        # one "epoch" = epoch_size sentences of MLM training
        while trainer.n_sentences < trainer.epoch_size:
            # MLM steps
            trainer.mlm_step(params.lambda_mlm)

            trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(trainer)

        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        trainer.save_best_model(scores)
        trainer.save_periodic()
        trainer.end_epoch(scores)
Esempio n. 10
0
def main(params):
    """Distillation-style training: a big (teacher) model guides a small
    (layer-cut) student model.

    Builds both a full-size model and, when ``params.cut_layer`` is set, a
    reduced one; the small trainer's step methods receive the big trainer so
    the student can learn from the teacher. Only the small model is evaluated
    and checkpointed.
    """

    # initialize the multi-GPU / multi-node training
    init_distributed_mode(params)

    # initialize the experiment
    logger = initialize_exp(params)

    # initialize SLURM signal handler for time limit / pre-emption
    init_signal_handler()

    # load data
    data = load_data(params)

    # build the big model
    if params.encoder_only:
        big_model = build_model(params, data['dico'], cut=False)
    else:
        # modification point 1 (translated from Chinese)
        big_encoder, big_decoder = build_model(params, data['dico'], cut=False)

    # if we cut some layers, must build a small model
    if params.cut_layer:
        if params.encoder_only:
            small_model = build_model(params, data['dico'], cut=True)
        else:
            # modification point 1 (translated from Chinese)
            small_encoder, small_decoder = build_model(params,
                                                       data['dico'],
                                                       cut=True)

    # build the big trainer, reload potential checkpoints
    # the big trainer is used to train, so need't a evaluator for it
    if params.encoder_only:
        big_trainer = SingleTrainer(big_model, data, params)
    else:
        big_trainer = EncDecTrainer(big_encoder, big_decoder, data, params)

    # force all objective coefficients to 1 for the distillation run
    params.lambda_mlm = "1"
    params.lambda_clm = "1"
    params.lambda_pc = "1"
    params.lambda_ae = "1"
    params.lambda_mt = "1"
    params.lambda_bt = "1"

    # build the small model, and use it for evaluator
    if params.encoder_only:
        small_trainer = small_SingleTrainer(small_model, data, params)
        evaluator = SingleEvaluator(small_trainer, data, params)
    else:
        small_trainer = small_EncDecTrainer(small_encoder, small_decoder, data,
                                            params)
        evaluator = EncDecEvaluator(small_trainer, data, params)

    # evaluation only for the small trainer
    if params.eval_only:
        scores = evaluator.run_all_evals(small_trainer)
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        logger.info("__log__:%s" % json.dumps(scores))
        exit()

    # set sampling probabilities for training
    set_sampling_probs(data, params)

    # language model training
    for count in range(params.max_epoch):

        logger.info("============ Starting epoch %i ... ============" %
                    small_trainer.epoch)

        small_trainer.n_sentences = 0

        # one "epoch" = epoch_size sentences; most steps also get the big
        # trainer so the student is guided by the teacher
        while small_trainer.n_sentences < small_trainer.epoch_size:

            # CLM steps
            for lang1, lang2 in shuf_order(params.clm_steps, params):
                small_trainer.clm_step(lang1, lang2, params.lambda_clm,
                                       big_trainer)
            # MLM steps (also includes TLM if lang2 is not None)
            for lang1, lang2 in shuf_order(params.mlm_steps, params):
                small_trainer.mlm_step(lang1, lang2, params.lambda_mlm,
                                       big_trainer)

            # parallel classification steps
            for lang1, lang2 in shuf_order(params.pc_steps, params):
                small_trainer.pc_step(lang1, lang2, params.lambda_pc)

            # denoising auto-encoder steps (MT step with lang1 == lang2)
            for lang in shuf_order(params.ae_steps):
                small_trainer.mt_step(lang, lang, params.lambda_ae,
                                      big_trainer)

            # machine translation steps
            for lang1, lang2 in shuf_order(params.mt_steps, params):
                small_trainer.mt_step(lang1, lang2, params.lambda_mt,
                                      big_trainer)

            # back-translation steps
            for lang1, lang2, lang3 in shuf_order(params.bt_steps):
                small_trainer.bt_step(lang1, lang2, lang3, params.lambda_bt)

            small_trainer.iter()

        logger.info("============ End of epoch %i ============" %
                    small_trainer.epoch)

        # evaluate perplexity
        scores = evaluator.run_all_evals(small_trainer)

        # print / JSON log
        for k, v in scores.items():
            logger.info("%s -> %.6f" % (k, v))
        if params.is_master:
            logger.info("__log__:%s" % json.dumps(scores))

        # end of epoch
        small_trainer.save_best_model(scores)
        small_trainer.save_periodic()
        small_trainer.end_epoch(scores)