Example #1
def main(prefix="",
         url_feature="",
         url_pred="",
         url_len="",
         url_weight="",
         batch_size=126,
         max_input_len=30,
         max_sent_length=24,
         embed_size=13,
         acc_range=10,
         sight=1,
         is_classify=0,
         decoder=1,
         decoder_size=4,
         loss='mae',
         context_meaning=1,
         rnn_layer=1):
    # init model
    model = Model(max_input_len=max_input_len,
                  max_sent_len=max_sent_length,
                  embed_size=embed_size,
                  using_bidirection=False,
                  fw_cell="basic",
                  bw_cell="basic",
                  batch_size=batch_size,
                  is_classify=is_classify,
                  use_tanh_prediction=True,
                  target=5 if is_classify else 1,
                  loss=loss,
                  acc_range=acc_range,
                  input_rnn=False,
                  sight=sight,
                  use_decoder=decoder,
                  dvs=decoder_size,
                  rnn_layer=rnn_layer)

    # model.init_data_node()
    tf.reset_default_graph()
    with tf.device('/%s' % p.device):
        model.init_ops()
        saver = tf.train.Saver()

    utils.assert_url(url_feature)
    if url_pred:
        utils.assert_url(url_pred)
        dataset = utils.load_file(url_feature)
        pred = utils.load_file(url_pred, False)
        if is_classify:
            pred = [
                utils.get_pm25_class(round(float(x.replace("\n", ""))))
                for x in pred
            ]
        else:
            pred = [round(float(x.replace("\n", ""))) for x in pred]
        if max_input_len > 1:
            utils.assert_url(url_len)
            data_len = utils.load_file(url_len)
        else:
            data_len = None
        _, test = utils.process_data(dataset,
                                     data_len,
                                     pred,
                                     batch_size,
                                     max_input_len,
                                     max_sent_length,
                                     True,
                                     sight,
                                     context_meaning=context_meaning)
    else:
        test = utils.load_file(url_feature)

    tconfig = tf.ConfigProto(allow_soft_placement=True)

    with tf.Session(config=tconfig) as session:
        init = tf.global_variables_initializer()
        session.run(init)
        # saver = tf.train.import_meta_graph(url_weight + ".meta")
        saver.restore(session, url_weight)
        print('==> running model')
        _, _, preds, lb = model.run_epoch(session, test, shuffle=False)
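        # ad-hoc correction: shift any prediction above 45 up by 10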
        preds = [x if x <= 45 else (x + 10) for x in preds]
        lb = lb[0:len(preds)]
        # print('Validation loss: {}'.format(valid_loss))
        # print('Validation accuracy: {}'.format(valid_accuracy))
        # tmp = 'Test validation accuracy: %.4f' % valid_accuracy
        # utils.save_file("test_acc/%s_test_acc.txt" % prefix, tmp, False)
        evaluate(preds, lb, acc_range, is_classify)
        utils.save_predictions(preds, lb, p.test_preds % prefix)
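
A hypothetical test-time invocation might look like the sketch below; the paths, prefix, and weight file are placeholders (the config module `p` and the `utils` helpers come from the surrounding project):

# Hypothetical call; all paths are placeholders, not from the source repository.
main(prefix="seoul_",
     url_feature="data/test_features.pkl",
     url_pred="data/test_pm25.txt",
     url_len="data/test_lengths.pkl",
     url_weight="weights/seoul_daegu.weights",
     acc_range=10,
     is_classify=0)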
Example #2
def run_neural_nets(url_feature="",
                    attention_url="",
                    url_weight="sp",
                    encoder_length=24,
                    encoder_size=15,
                    decoder_length=8,
                    decoder_size=9,
                    is_test=False,
                    restore=False,
                    model="NN",
                    pre_train=False):
    if model == "NN":
        model = NeuralNetwork(encoder_length=encoder_length,
                              encoder_vector_size=encoder_size,
                              decoder_length=decoder_length,
                              decoder_vector_size=decoder_size)
    elif model == "SAE":
        model = StackAutoEncoder(encoder_length=encoder_length,
                                 encoder_vector_size=encoder_size,
                                 decoder_length=decoder_length,
                                 pre_train=pre_train)
    else:
        model = Adain(encoder_length=encoder_length,
                      encoder_vector_size=encoder_size,
                      decoder_length=decoder_length)
    print('==> initializing models')
    with tf.device('/%s' % p.device):
        model.init_model()
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    utils.assert_url(url_feature)

    tconfig = get_gpu_options()
    sum_dir = 'summaries'
    if not utils.check_file(sum_dir):
        os.makedirs(sum_dir)

    train_writer = None
    with tf.Session(config=tconfig) as session:
        if not restore:
            session.run(init)
        else:
            print("==> Reload pre-trained weights")
            saver.restore(session, url_weight)
            url_weight = url_weight.split("/")[-1]
            # strip the ".weights" suffix explicitly (rstrip removes characters, not a suffix)
            if url_weight.endswith(".weights"):
                url_weight = url_weight[:-len(".weights")]
        
        if not is_test:
            suf = time.strftime("%Y.%m.%d_%H.%M")
            train_writer = tf.summary.FileWriter(sum_dir + "/" + url_weight + "_train", session.graph, filename_suffix=suf)
            valid_writer = tf.summary.FileWriter(sum_dir + "/" + url_weight + "_valid", session.graph, filename_suffix=suf)

        print("==> Loading dataset")
        dataset = utils.load_file(url_feature)
        if dataset:
            dataset = np.asarray(dataset, dtype=np.float32)
            lt = len(dataset)
            st = int(lt/2)
            lt = lt - st
            dataset = dataset[st:,:,:]
            train, valid = utils.process_data_grid(lt, p.batch_size, encoder_length, decoder_length, is_test)
            if attention_url:
                attention_data = utils.load_file(attention_url)
            else:
                attention_data = None
            model.set_data(dataset, train, valid, attention_data, session)
            if not is_test:
                best_val_epoch = 0
                best_val_loss = float('inf')
                # best_overall_val_loss = float('inf')
                print('==> starting training')
                for epoch in xrange(p.total_iteration):
                    print('Epoch {}'.format(epoch))
                    start = time.time()
                    train_loss, _ = model.run_epoch(session, train, epoch, train_writer, train_op=model.train_op, train=True)
                    print('Training loss: {}'.format(train_loss))

                    valid_loss, _ = model.run_epoch(session, valid, epoch, valid_writer)
                    print('Validation loss: {}'.format(valid_loss))

                    if valid_loss < best_val_loss:
                        best_val_loss = valid_loss
                        best_val_epoch = epoch
                        print('Saving weights')
                        saver.save(session, 'weights/%s.weights' % url_weight)

                    if (epoch - best_val_epoch) > p.early_stopping:
                        break
                    print('Total time: {}'.format(time.time() - start))
            else:
                # saver.restore(session, url_weight)
                print('==> running model')
                _, preds = model.run_epoch(session, model.train, shuffle=False)
                pt = re.compile("weights/([A-Za-z0-9_.]*).weights")
                name = pt.match(url_weight)
                if name:
                    name_s = name.group(1)
                else:
                    name_s = url_weight
                utils.save_file("test_sp/%s" % name_s, preds)
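
A sketch of how this function might be driven, first to train and then to evaluate; the feature path and weight name are placeholders:

# Hypothetical usage (placeholder paths and names).
run_neural_nets(url_feature="data/grid_features.pkl", url_weight="nn_sp",
                model="NN", is_test=False, restore=False)
# Reload the saved weights later and run in test mode:
run_neural_nets(url_feature="data/grid_features.pkl",
                url_weight="weights/nn_sp.weights",
                model="NN", is_test=True, restore=True)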
Example #3
def train_baseline(url_feature="",
                   attention_url="",
                   url_weight="sp",
                   batch_size=128,
                   encoder_length=24,
                   embed_size=None,
                   loss=None,
                   decoder_length=24,
                   decoder_size=4,
                   grid_size=25,
                   rnn_layers=1,
                   dtype="grid",
                   is_folder=False,
                   is_test=False,
                   use_cnn=True,
                   restore=False,
                   model_name="",
                   validation_url="",
                   attention_valid_url="",
                   best_val_loss=None,
                   forecast_factor=0):
    if model_name == "APNET":
        model = APNet(encoder_length=encoder_length,
                      encode_vector_size=embed_size,
                      batch_size=batch_size,
                      decoder_length=decoder_length,
                      decode_vector_size=decoder_size,
                      grid_size=grid_size,
                      use_attention=bool(attention_url),
                      forecast_factor=forecast_factor)
    elif model_name == "TNET":
        model = TNet(encoder_length=8, decoder_length=8, grid_size=32)
    elif model_name == "TNETLSTM":
        model = TNetLSTM(encoder_length=8, decoder_length=8, grid_size=32)
    elif model_name == "SRCN":
        model = SRCN(encoder_length=8, decoder_length=8, grid_size=32)
    else:
        model = BaselineModel(encoder_length=encoder_length,
                              encode_vector_size=embed_size,
                              batch_size=batch_size,
                              decoder_length=decoder_length,
                              decode_vector_size=decoder_size,
                              rnn_layers=rnn_layers,
                              dtype=dtype,
                              grid_size=grid_size,
                              use_cnn=use_cnn,
                              loss=loss,
                              use_attention=bool(attention_url),
                              forecast_factor=forecast_factor)
    print('==> initializing models')
    with tf.device('/%s' % p.device):
        model.init_ops(is_train=(not is_test))
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    utils.assert_url(url_feature)
    tconfig = get_gpu_options()
    sum_dir = 'summaries'
    if not utils.check_file(sum_dir):
        os.makedirs(sum_dir)

    train_writer = None
    valid_writer = None
    with tf.Session(config=tconfig) as session:
        if not restore:
            session.run(init)
        else:
            print("==> Reload pre-trained weights")
            saver.restore(session, url_weight)
            url_weight = url_weight.split("/")[-1]
            # strip the ".weights" suffix explicitly (rstrip removes characters, not a suffix)
            if url_weight.endswith(".weights"):
                url_weight = url_weight[:-len(".weights")]
        
        if not is_test:
            # suf = time.strftime("%Y.%m.%d_%H.%M")
            csn = int(time.time())
            train_writer = tf.summary.FileWriter(sum_dir + "/" + url_weight + "_train_" + str(csn), session.graph)
            valid_writer = tf.summary.FileWriter(sum_dir + "/" + url_weight + "_valid_" + str(csn), session.graph)
            if restore:
                url_weight = url_weight + "_" + str(csn)
       
        folders = None
        best_val_loss = float('inf')
        if is_folder:
            folders = sorted(os.listdir(url_feature))
            if attention_url:
                a_folders = sorted(os.listdir(attention_url))
                folders = zip(folders, a_folders)
            last_epoch = 0
            for i, files in enumerate(folders):
                if attention_url:
                    x, y = files
                    att_url = os.path.join(attention_url, y)
                    print("==> Training set (%i, %s, %s)" % (i + 1, x, y))
                else: 
                    x = files
                    att_url = ""
                    print("==> Training set (%i, %s)" % (i + 1, x))
                last_epoch, best_val_loss = execute(os.path.join(url_feature, x), att_url, url_weight, model, session, saver, batch_size, encoder_length, decoder_length, 
                                    is_test, (train_writer, valid_writer), last_epoch, validation_url=validation_url, attention_valid_url=attention_valid_url, best_val_loss=best_val_loss)
                if not is_test:
                    print("best val loss:" + str(best_val_loss))
        else:
            _, best_val_loss = execute(url_feature, attention_url, url_weight, model, session, saver, batch_size, encoder_length, decoder_length, is_test, (train_writer, valid_writer), 
                        validation_url=validation_url, attention_valid_url=attention_valid_url, best_val_loss=best_val_loss)
            if not is_test:
                print("best val loss:" + str(best_val_loss))
Example #4
def main(prefix="",
         url_feature="",
         url_pred="",
         url_len="",
         url_feature1="",
         url_pred1="",
         url_len1="",
         batch_size=126,
         max_input_len=30,
         max_sent_length=24,
         lr_decayable=False,
         using_bidirection=False,
         forward_cell='',
         backward_cell='',
         embed_size=None,
         is_classify=True,
         loss=None,
         acc_range=None,
         usp=None,
         input_rnn=None,
         reload_data=True,
         pred_sight=1,
         decoder=1,
         decoder_size=4,
         is_weighted=0,
         context_info=1,
         rnn_layer=1):
    target = 5 if is_classify else 1
    model = Model(max_input_len=max_input_len,
                  max_sent_len=max_sent_length,
                  embed_size=embed_size,
                  learning_rate=0.001,
                  lr_decayable=lr_decayable,
                  using_bidirection=using_bidirection,
                  fw_cell=forward_cell,
                  bw_cell=backward_cell,
                  batch_size=batch_size,
                  target=target,
                  is_classify=is_classify,
                  loss=loss,
                  acc_range=acc_range,
                  use_tanh_prediction=usp,
                  input_rnn=input_rnn,
                  sight=pred_sight,
                  dvs=decoder_size,
                  use_decoder=decoder,
                  is_weighted=is_weighted,
                  rnn_layer=rnn_layer)
    # model.init_data_node()
    with tf.device('/%s' % p.device):
        model.init_ops()
        print('==> initializing variables')
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

    if reload_data:
        print("Loading dataset")
        utils.assert_url(url_feature)
        utils.assert_url(url_pred)
        if max_input_len > 1:
            utils.assert_url(url_len)
            data_len = utils.load_file(url_len)
        else:
            data_len = None
        # if url_feature1:
        #     utils.assert_url(url_feature1)
        #     utils.assert_url(url_pred1)
        #     if max_input_len > 1:
        #         utils.assert_url(url_len1)
        dataset = utils.load_file(url_feature)
        pred = utils.load_file(url_pred, False)
        if is_classify:
            pred = [
                utils.get_pm25_class(round(float(x.replace("\n", ""))))
                for x in pred
            ]
        else:
            pred = [round(float(x.replace("\n", ""))) for x in pred]
        train, dev = utils.process_data(dataset, data_len, pred, batch_size,
                                        max_input_len, max_sent_length, False,
                                        pred_sight, context_info)
        # utils.save_file(p.train_url % ("_" + prefix + "_" + str(max_sent_length)), train)
        # utils.save_file(p.dev_url % ("_" + prefix + "_" +str(max_sent_length)), dev)
    else:
        utils.assert_url(p.train_url)
        utils.assert_url(p.dev_url)
        train = utils.load_file(p.train_url)
        dev = utils.load_file(p.dev_url)
    model.set_data(train, dev)

    gpu_options = None
    if p.device == "gpu":
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=p.gpu_fraction)

    tconfig = tf.ConfigProto(allow_soft_placement=True,
                             gpu_options=gpu_options)

    with tf.Session(config=tconfig) as session:

        sum_dir = 'summaries/' + time.strftime("%Y-%m-%d %H %M")
        if not utils.check_file(sum_dir):
            os.makedirs(sum_dir)
        train_writer = tf.summary.FileWriter(sum_dir, session.graph)

        session.run(init)

        best_val_epoch = 0
        prev_epoch_loss = float('inf')
        best_val_loss = float('inf')
        best_val_accuracy = 0.0
        best_overall_val_loss = float('inf')

        if url_pred1:
            print('==> restoring weights')
            saver.restore(session, '%s' % url_pred1)

        print('==> starting training')
        train_losses, train_accuracies = [], []
        val_losses, val_acces, best_preds, best_lb = [], [], [], []
        for epoch in xrange(p.total_iteration):
            print('Epoch {}'.format(epoch))
            start = time.time()

            train_loss, train_accuracy, _, _ = model.run_epoch(
                session,
                model.train,
                epoch,
                train_writer,
                train_op=model.train_step,
                train=True)
            train_losses.append(train_loss)
            train_accuracies.append(train_accuracy)
            print('Training loss: {}'.format(train_loss))
            # print('Training accuracy: {}'.format(train_accuracy))

            valid_loss, valid_accuracy, preds, lb = model.run_epoch(
                session, model.valid)
            val_losses.append(valid_loss)
            val_acces.append(valid_accuracy)
            print('Validation loss: {}'.format(valid_loss))
            # print('Validation accuracy: {}'.format(valid_accuracy))

            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                best_val_epoch = epoch
                if best_val_loss < best_overall_val_loss:
                    print('Saving weights')
                    best_overall_val_loss = best_val_loss
                    saver.save(session, 'weights/%sdaegu.weights' % prefix)
                    best_preds = preds
                    best_lb = lb
                    # utils.save_predictions(best_preds, best_lb, p.train_preds % prefix)
                    if best_val_accuracy < valid_accuracy:
                        best_val_accuracy = valid_accuracy

            if (epoch - best_val_epoch) > p.early_stopping:
                break
            print('Total time: {}'.format(time.time() - start))
        # tmp = 'Best validation accuracy: %.4f' % best_val_accuracy
        # print(tmp)
        # utils.save_file("accuracies/%saccuracy.txt" % prefix, tmp, False)
        utils.save_file(
            "logs/%slosses.pkl" % prefix, {
                "train_loss": train_losses,
                "train_acc": train_accuracies,
                "valid_loss": val_losses,
                "valid_acc": val_acces
            })
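
A hypothetical training call for the regression setting; the data paths are placeholders and `p.train_url` / `p.dev_url` belong to the project's config:

# Hypothetical training invocation (placeholder paths).
main(prefix="daegu_",
     url_feature="data/train_features.pkl",
     url_pred="data/train_pm25.txt",
     url_len="data/train_lengths.pkl",
     is_classify=False,
     loss="mae",
     acc_range=10,
     reload_data=True)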
Example #5
def train_gan(url_feature="",
              attention_url="",
              url_weight="sp",
              batch_size=128,
              encoder_length=24,
              embed_size=None,
              decoder_length=24,
              decoder_size=4,
              grid_size=25,
              is_folder=False,
              is_test=False,
              restore=False,
              model_name="APGAN",
              forecast_factor=0):
    if model_name == "APGAN":
        model = APGan(encoder_length=encoder_length,
                      encode_vector_size=embed_size,
                      batch_size=batch_size,
                      decode_vector_size=decoder_size,
                      decoder_length=decoder_length,
                      grid_size=grid_size,
                      forecast_factor=forecast_factor,
                      use_attention=bool(attention_url))
    elif model_name == "MASKGAN":
        model = MaskGan(encoder_length=encoder_length,
                        encode_vector_size=embed_size,
                        batch_size=batch_size,
                        decode_vector_size=decoder_size,
                        grid_size=grid_size,
                        use_cnn=1)
    elif model_name == "APGAN_LSTM":
        model = APGAN_LSTM(encoder_length=encoder_length,
                           encode_vector_size=embed_size,
                           batch_size=batch_size,
                           decode_vector_size=decoder_size,
                           decoder_length=decoder_length,
                           grid_size=grid_size)
    elif model_name == "CAPGAN":
        model = CAPGan(encoder_length=encoder_length,
                       encode_vector_size=embed_size,
                       batch_size=batch_size,
                       decode_vector_size=decoder_size,
                       grid_size=grid_size)
    elif model_name == "TGAN":
        model = TGAN(encoder_length=8, decoder_length=8, grid_size=32)
    else:
        model = TGANLSTM(encoder_length=8, decoder_length=8, grid_size=32)
    tconfig = get_gpu_options()
    utils.assert_url(url_feature)
    if not utils.check_file('summaries'):
        os.makedirs('summaries')
    print('==> initializing models')
    with tf.device('/%s' % p.device):
        model.init_ops(not is_test)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    train_writer = None
    with tf.Session(config=tconfig) as session:
        if not restore:
            session.run(init)
        else:
            print("==> Reload pre-trained weights")
            saver.restore(session, url_weight)
        csn = int(time.time())
        if not is_test:
            url_weight = url_weight.split("/")[-1]
            # strip the ".weights" suffix explicitly (rstrip removes characters, not a suffix)
            if url_weight.endswith(".weights"):
                url_weight = url_weight[:-len(".weights")]
            train_writer = tf.summary.FileWriter(
                "summaries/%s_%i" % (url_weight, csn), session.graph)
        folders = None
        if is_folder:
            folders = os.listdir(url_feature)
            folders = sorted(folders)
            if attention_url:
                a_folders = sorted(os.listdir(attention_url))
                folders = zip(folders, a_folders)
            for i, files in enumerate(folders):
                if attention_url:
                    x, y = files
                    att_url = os.path.join(attention_url, y)
                    print("==> Training set (%i, %s, %s)" % (i + 1, x, y))
                else:
                    att_url = None
                    x = files
                    print("==> Training set (%i, %s)" % (i + 1, x))
                execute_gan(os.path.join(url_feature, x), att_url, url_weight,
                            model, session, saver, batch_size, encoder_length,
                            decoder_length, is_test, train_writer,
                            i * p.total_iteration)
        else:
            execute_gan(url_feature, attention_url, url_weight, model, session,
                        saver, batch_size, encoder_length, decoder_length,
                        is_test, train_writer)
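
A possible GAN training call; the feature path and weight name below are placeholders:

# Hypothetical invocation (placeholder paths and names).
train_gan(url_feature="data/grid_features.pkl",
          url_weight="apgan_sp",
          model_name="APGAN",
          embed_size=15,
          encoder_length=24,
          decoder_length=24,
          is_test=False,
          restore=False)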
Example #6
def train_baseline(url_feature="",
                   attention_url="",
                   url_weight="sp",
                   batch_size=128,
                   encoder_length=24,
                   embed_size=None,
                   loss=None,
                   decoder_length=24,
                   decoder_size=4,
                   grid_size=25,
                   rnn_layers=1,
                   dtype="grid",
                   is_folder=False,
                   is_test=False,
                   use_cnn=True,
                   restore=False,
                   model_name="",
                   validation_url="",
                   attention_valid_url="",
                   best_val_loss=None,
                   forecast_factor=0,
                   encoder_type=3,
                   atttention_hidden_size=17,
                   use_gen_cnn=True,
                   label_path="",
                   num_class=0,
                   districts=25):
    if model_name == "APNET":
        model = APNet(encoder_length=encoder_length,
                      encode_vector_size=embed_size,
                      batch_size=batch_size,
                      decoder_length=decoder_length,
                      decode_vector_size=decoder_size,
                      grid_size=grid_size,
                      use_attention=bool(attention_url),
                      forecast_factor=forecast_factor,
                      mtype=encoder_type,
                      atttention_hidden_size=atttention_hidden_size,
                      use_gen_cnn=use_gen_cnn,
                      num_class=num_class,
                      districts=districts)
    elif model_name == "APNET_CHINA":
        model = APNetChina(encoder_length=encoder_length,
                           encode_vector_size=embed_size,
                           batch_size=batch_size,
                           decoder_length=decoder_length,
                           grid_size=grid_size,
                           use_attention=bool(attention_url),
                           forecast_factor=forecast_factor,
                           atttention_hidden_size=atttention_hidden_size,
                           num_class=num_class,
                           districts=districts)
    elif model_name == "TNET":
        model = TNet(encoder_length=8, decoder_length=8, grid_size=32)
    elif model_name == "TNETLSTM":
        model = TNetLSTM(encoder_length=8, decoder_length=8, grid_size=32)
    elif model_name == "SRCN":
        model = SRCN(encoder_length=8, decoder_length=8, grid_size=32)
    else:
        model = BaselineModel(encoder_length=encoder_length,
                              encode_vector_size=embed_size,
                              batch_size=batch_size,
                              decoder_length=decoder_length,
                              decode_vector_size=decoder_size,
                              rnn_layers=rnn_layers,
                              dtype=dtype,
                              grid_size=grid_size,
                              use_cnn=use_cnn,
                              loss=loss,
                              use_attention=bool(attention_url),
                              forecast_factor=forecast_factor,
                              mtype=encoder_type,
                              atttention_hidden_size=atttention_hidden_size,
                              use_gen_cnn=use_gen_cnn,
                              num_class=num_class,
                              districts=districts)
    print('==> initializing models')
    with tf.device('/%s' % p.device):
        model.init_ops(is_train=(not is_test))
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    utils.assert_url(url_feature)
    tconfig = get_gpu_options()
    sum_dir = 'summaries'
    if not utils.check_file(sum_dir):
        os.makedirs(sum_dir)

    train_writer = None
    valid_writer = None
    with tf.Session(config=tconfig) as session:
        if not restore:
            session.run(init)
        else:
            print("==> Reload pre-trained weights")
            # var = [v for v in tf.get_default_graph().as_graph_def().node if not "decoder_output_class" in v.name and not "device" in v.name and not "Adam" in v.name]
            if not is_test and model_name == "APNET_CHINA":
                # loading from seoul to china
                var = [
                    v for v in tf.global_variables()
                    if not "decoder_output" in v.name
                    and not "embedding" in v.name
                ]
            else:
                print("testing loading weights")
                var = [
                    v for v in tf.global_variables()
                    if not "embedding" in v.name
                ]
            saver = tf.train.Saver(var)
            saver.restore(session, url_weight)
            url_weight = url_weight.split("/")[-1]
            # strip the ".weights" suffix explicitly (rstrip removes characters, not a suffix)
            if url_weight.endswith(".weights"):
                url_weight = url_weight[:-len(".weights")]
            # if not is_test:
            initialize_uninitialized(session)
        # The regular setting would be saver = tf.train.Saver();
        # the embedding variables are excluded so they are not saved with the weights.
        var = [v for v in tf.global_variables() if not "embedding" in v.name]
        saver = tf.train.Saver(var)

        if not is_test:
            # suf = time.strftime("%Y.%m.%d_%H.%M")
            csn = int(time.time())
            train_writer = tf.summary.FileWriter(
                sum_dir + "/" + url_weight + "_train_" + str(csn),
                session.graph)
            valid_writer = tf.summary.FileWriter(
                sum_dir + "/" + url_weight + "_valid_" + str(csn),
                session.graph)
            if restore:
                url_weight = url_weight + "_" + str(csn)

        folders = None
        best_val_loss = float('inf')
        if is_folder:
            folders = sorted(os.listdir(url_feature))
            if attention_url:
                a_folders = sorted(os.listdir(attention_url))
                folders = zip(folders, a_folders)
            last_epoch = 0
            for i, files in enumerate(folders):
                if attention_url:
                    x, y = files
                    att_url = os.path.join(attention_url, y)
                    print("==> Training set (%i, %s, %s)" % (i + 1, x, y))
                else:
                    x = files
                    att_url = ""
                    print("==> Training set (%i, %s)" % (i + 1, x))
                last_epoch, best_val_loss = execute(
                    os.path.join(url_feature, x),
                    att_url,
                    url_weight,
                    model,
                    session,
                    saver,
                    batch_size,
                    encoder_length,
                    decoder_length,
                    is_test, (train_writer, valid_writer),
                    last_epoch,
                    validation_url=validation_url,
                    attention_valid_url=attention_valid_url,
                    best_val_loss=best_val_loss)
                if not is_test:
                    print("best val loss:" + str(best_val_loss))
        else:
            _, best_val_loss = execute(url_feature,
                                       attention_url,
                                       url_weight,
                                       model,
                                       session,
                                       saver,
                                       batch_size,
                                       encoder_length,
                                       decoder_length,
                                       is_test, (train_writer, valid_writer),
                                       validation_url=validation_url,
                                       attention_valid_url=attention_valid_url,
                                       best_val_loss=best_val_loss)
            if not is_test:
                print("best val loss:" + str(best_val_loss))