Beispiel #1
0
def run_neural_nets(dataset, url_weight="sp", encoder_length=24, encoder_size=15,
                    decoder_length=8, decoder_size=9, is_test=False, restore=False,
                    model="NN", pre_train=False, forecast_factor=0):
    """Build the selected network, then either train it or run it on ``dataset``.

    ``model`` picks the network class: "NN" -> NeuralNetwork, "SAE" ->
    StackAutoEncoder, any other value -> Adain.  The default TensorFlow graph
    is reset first, so every call starts from an empty graph.

    Checkpoints are read from / written to
    ``weights/<url_weight>_<decoder_length>h.weights``.

    When ``is_test`` is false the model is trained for at most
    ``p.total_iteration`` epochs; the checkpoint is overwritten every time the
    validation loss improves, and training stops once more than
    ``p.early_stopping`` epochs have passed since the best epoch.

    Returns the second element of ``run_epoch``'s result when ``is_test`` is
    true, otherwise ``None``.
    """
    tf.reset_default_graph()
    print("training %s with decoder_length = %i" % (model, decoder_length))

    # Keyword arguments accepted by all three model constructors.
    shared_kwargs = dict(
        encoder_length=encoder_length,
        encoder_vector_size=encoder_size,
        decoder_length=decoder_length,
    )
    if model == "NN":
        net = NeuralNetwork(decoder_vector_size=decoder_size, **shared_kwargs)
    elif model == "SAE":
        net = StackAutoEncoder(pre_train=pre_train, forecast_factor=forecast_factor, **shared_kwargs)
    else:
        net = Adain(forecast_factor=forecast_factor, **shared_kwargs)

    print('==> initializing models')
    with tf.device('/%s' % p.device):
        net.init_model()
        # Both ops must be created after init_model() so they see its variables.
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

    checkpoint_path = "weights/%s_%ih.weights" % (url_weight, decoder_length)
    session_config = get_gpu_options()
    with tf.Session(config=session_config) as session:
        if restore:
            print("==> Reload pre-trained weights")
            saver.restore(session, checkpoint_path)
        else:
            session.run(init_op)

        print("==> Loading dataset")

        train_split, valid_split = utils.process_data_grid(
            len(dataset), p.batch_size, encoder_length, decoder_length, is_test)
        net.set_data(dataset, train_split, valid_split, None, session)

        if is_test:
            print('==> running model')
            _, predictions = net.run_epoch(session, net.train, shuffle=False, stride=2)
            return predictions

        best_epoch = 0
        lowest_valid_loss = float('inf')
        print('==> starting training')
        for epoch in xrange(p.total_iteration):
            print('Epoch {}'.format(epoch))
            epoch_started = time.time()

            train_loss, _ = net.run_epoch(
                session, train_split, epoch, None, train_op=net.train_op, train=True)
            print('Training loss: {}'.format(train_loss))

            valid_loss, _ = net.run_epoch(session, valid_split, epoch, None)
            print('Validation loss: {}'.format(valid_loss))

            # Keep only the checkpoint with the best validation loss so far.
            if valid_loss < lowest_valid_loss:
                lowest_valid_loss, best_epoch = valid_loss, epoch
                print('Saving weights')
                saver.save(session, checkpoint_path)

            # Early stopping; note the timing line below is skipped on exit.
            if epoch - best_epoch > p.early_stopping:
                break
            print('Total time: {}'.format(time.time() - epoch_started))
    return None
Beispiel #2
0
def run_neural_nets(url_feature="", attention_url="", url_weight="sp", encoder_length=24, encoder_size=15, decoder_length=8, decoder_size=9, is_test=False, restore=False, model="NN", pre_train=False):
    """Build the selected network, load features from disk, then train or predict.

    ``model`` picks the network class: "NN" -> NeuralNetwork, "SAE" ->
    StackAutoEncoder, any other value -> Adain.

    Args:
        url_feature: path of the feature file, read with ``utils.load_file``.
        attention_url: optional path of attention data; when empty, ``None``
            is handed to ``model.set_data`` instead.
        url_weight: when ``restore`` is true, the checkpoint path to restore
            from; afterwards it is reduced to its base name (directory and a
            trailing ".weights" removed).  That name is used for the summary
            directories, the saved checkpoint and the prediction file.
        is_test: run the model and save predictions instead of training.
        restore: restore variables from ``url_weight`` instead of
            initialising them.

    Training runs for at most ``p.total_iteration`` epochs, saves
    ``weights/<url_weight>.weights`` whenever the validation loss improves and
    stops once more than ``p.early_stopping`` epochs have passed since the
    best one.  In test mode predictions are written to ``test_sp/<name>``.

    Returns ``None`` in every case.
    """
    if model == "NN":
        model = NeuralNetwork(encoder_length=encoder_length, encoder_vector_size=encoder_size, decoder_length=decoder_length, decoder_vector_size=decoder_size)
    elif model == "SAE":
        model = StackAutoEncoder(encoder_length=encoder_length, encoder_vector_size=encoder_size, decoder_length=decoder_length, pre_train=pre_train)
    else:
        model = Adain(encoder_length=encoder_length, encoder_vector_size=encoder_size, decoder_length=decoder_length)
    print('==> initializing models')
    with tf.device('/%s' % p.device):
        model.init_model()
        # Both ops must be created after init_model() so they see its variables.
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    utils.assert_url(url_feature)

    tconfig = get_gpu_options()
    sum_dir = 'summaries'
    if not utils.check_file(sum_dir):
        os.makedirs(sum_dir)

    train_writer = None
    valid_writer = None
    try:
        with tf.Session(config=tconfig) as session:
            if not restore:
                session.run(init)
            else:
                print("==> Reload pre-trained weights")
                saver.restore(session, url_weight)
                url_weight = url_weight.split("/")[-1]
                # str.rstrip(".weights") would strip any trailing run of the
                # characters ".weights" (e.g. "models.weights" -> "model"),
                # so remove the literal suffix explicitly.
                weight_suffix = ".weights"
                if url_weight.endswith(weight_suffix):
                    url_weight = url_weight[:-len(weight_suffix)]

            if not is_test:
                suf = time.strftime("%Y.%m.%d_%H.%M")
                train_writer = tf.summary.FileWriter(sum_dir + "/" + url_weight + "_train", session.graph, filename_suffix=suf)
                valid_writer = tf.summary.FileWriter(sum_dir + "/" + url_weight + "_valid", session.graph, filename_suffix=suf)

            print("==> Loading dataset")
            dataset = utils.load_file(url_feature)
            # NOTE(review): a falsy dataset silently skips everything below;
            # this also assumes load_file returns a plain sequence (truth-testing
            # a multi-element numpy array would raise) - confirm against utils.
            if dataset:
                dataset = np.asarray(dataset, dtype=np.float32)
                # Only the second half of the loaded data (along axis 0) is used.
                lt = len(dataset)
                st = lt // 2
                lt = lt - st
                dataset = dataset[st:, :, :]
                train, valid = utils.process_data_grid(lt, p.batch_size, encoder_length, decoder_length, is_test)
                if attention_url:
                    attention_data = utils.load_file(attention_url)
                else:
                    attention_data = None
                model.set_data(dataset, train, valid, attention_data, session)
                if not is_test:
                    best_val_epoch = 0
                    best_val_loss = float('inf')
                    print('==> starting training')
                    for epoch in xrange(p.total_iteration):
                        print('Epoch {}'.format(epoch))
                        start = time.time()
                        train_loss, _ = model.run_epoch(session, train, epoch, train_writer, train_op=model.train_op, train=True)
                        print('Training loss: {}'.format(train_loss))

                        valid_loss, _ = model.run_epoch(session, valid, epoch, valid_writer)
                        print('Validation loss: {}'.format(valid_loss))

                        # Keep only the checkpoint with the best validation loss.
                        if valid_loss < best_val_loss:
                            best_val_loss = valid_loss
                            best_val_epoch = epoch
                            print('Saving weights')
                            saver.save(session, 'weights/%s.weights' % url_weight)

                        # Early stopping once validation has stalled too long.
                        if (epoch - best_val_epoch) > p.early_stopping:
                            break
                        print('Total time: {}'.format(time.time() - start))
                else:
                    print('==> running model')
                    _, preds = model.run_epoch(session, model.train, shuffle=False)
                    # If url_weight still looks like "weights/<name>.weights"
                    # (i.e. it was not reduced by the restore branch), use
                    # <name> for the output file; otherwise use it as is.
                    pt = re.compile("weights/([A-Za-z0-9_.]*).weights")
                    name = pt.match(url_weight)
                    if name:
                        name_s = name.group(1)
                    else:
                        name_s = url_weight
                    utils.save_file("test_sp/%s" % name_s, preds)
    finally:
        # Flush and release the summary event files even if training fails.
        for writer in (train_writer, valid_writer):
            if writer is not None:
                writer.close()