max_grad_norm=args.max_grad_norm)
    compute_loss_fct = MultipleChoiceLossCompute(criterion, criterion,
                                                 args.lm_coef, model_opt)
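    # Initialize the transformer weights from the pretrained OpenAI checkpoint.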
    openai_model = OpenAIModel()
    openai_model.load_openai_pretrained_model(dh_model.transformer,
                                              n_ctx=n_ctx,
                                              n_special=n_special)

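    # Move the model to the target device and replicate it across available GPUs.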
    dh_model.to(device)
    dh_model = nn.DataParallel(dh_model)

    n_updates = 0
    n_epochs = 0
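    # STS-B is a regression task; its targets are assumed to be prepared elsewhere.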
    if dataset != 'stsb':
        trYt = trY
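    # Save an initial snapshot; the best-scoring weights are assumed to be
    # written back to this same path during training.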
    if submit:
        path = os.path.join(save_dir, desc, 'state_of_module')
        torch.save(dh_model.state_dict(), make_path(path))
    best_score = 0
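    # Fine-tune for args.n_iter epochs, logging metrics after each pass.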
    for i in range(args.n_iter):
        print("running epoch", i)
        training_engine.run_epoch()
        n_epochs += 1
        training_engine.log(save_dir, desc)
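    # In submission mode, reload the best checkpoint and write out predictions.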
    if submit:
        path = os.path.join(save_dir, desc, 'state_of_module')
        dh_model.load_state_dict(torch.load(path))
        training_engine.predict(dataset, args.submission_dir)
        prediction = Prediction()
        prediction.output_predictions(test_path, pred_path, out_path, topic)
        # The model is already on the device and wrapped in nn.DataParallel
        # (re-wrapping here would double-wrap it), so continue fine-tuning
        # from the best checkpoint before exporting.
        n_updates = 0
        n_epochs = 0
        best_score = 0

        for i in range(args.n_iter):
            print("running epoch", i)
            training_engine.run_epoch()
            n_epochs += 1
            # training_engine.log(save_dir, desc)

        torch.save(dh_model.state_dict(), 'model_state')

    else:
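        # Inference-only path: rebuild the model and restore the saved weights.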
        n_vocab = len(voc)
        max_len = 140
        n_special = 2
        n_ctx = max_len + 2
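        # The embedding table covers word tokens, special tokens, and position
        # indices, which this implementation encodes as extra vocabulary entries.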
        vocab = n_vocab + n_special + n_ctx
        n_batch_train = args.n_batch * max(n_gpu, 1)
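        # Special tokens are appended after the real word vocabulary.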
        start_token = n_vocab
        clf_token = n_vocab + 1
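        # Two-headed model: a language-modelling head plus a 2-way classification head.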
        dh_model = DoubleHeadModel(args, clf_token, ('classification', 2), vocab, n_ctx)
        dh_model.to(device)
        dh_model = nn.DataParallel(dh_model)

        dh_model.load_state_dict(torch.load('model_state', map_location=device))