Code example #1
File: copy_task_v2.py  Project: sznajder/ntm
import numpy as np
import tensorflow as tf  # TF 2.x eager-mode API

# CopyModel, SequenceCrossEntropyLoss and generate_random_strings are
# defined elsewhere in this project.


def train(config):
    model = CopyModel(batch_size=config['batch_size'],
                      vector_dim=config['vector_dim'],
                      model_type=config['model_type'],
                      cell_params=config['cell_params'][config['model_type']])
    optimizer = tf.keras.optimizers.Adam(learning_rate=config['learning_rate'])
    sequence_loss_func = SequenceCrossEntropyLoss()
    for batch_index in range(config['num_batches']):
        # Draw a random sequence length in [1, max_seq_length] for this batch.
        seq_length = tf.constant(
            np.random.randint(1, config['max_seq_length'] + 1), dtype=tf.int32)
        x = generate_random_strings(config['batch_size'], seq_length,
                                    config['vector_dim'])
        # Record the forward pass so gradients can be taken below.
        with tf.GradientTape() as tape:
            y_pred = model((x, seq_length))
            # In the copy task the target sequence is the input itself.
            loss = sequence_loss_func(y_true=x, y_pred=y_pred)
            loss = tf.reduce_mean(loss)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
        if batch_index % 100 == 0:
            # Periodically evaluate on a batch of fixed test length.
            x = generate_random_strings(config['batch_size'],
                                        config['test_seq_length'],
                                        config['vector_dim'])
            y_pred = model((x, config['test_seq_length']))
            # Reduce to a scalar, mirroring the training branch, so %f works.
            loss = tf.reduce_mean(sequence_loss_func(y_true=x, y_pred=y_pred))
            print("batch %d: loss %f" % (batch_index, loss))
            print("original string sample: ", x[0])
            print("copied string sample: ", y_pred[0])
Code example #2
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API
import matplotlib.pyplot as plt

# NTMCopyModel and generate_random_strings are defined elsewhere in this project.


def train(args):
    # TF 1.x unrolls a static graph, so one model is built per sequence
    # length; reuse=True makes them all share a single set of variables.
    model_list = [NTMCopyModel(args, 1)]
    for seq_length in range(2, args.max_seq_length + 1):
        model_list.append(NTMCopyModel(args, seq_length, reuse=True))
    with tf.Session() as sess:
        if args.restore_training:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(args.save_dir + '/' +
                                                 args.model)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            saver = tf.train.Saver(tf.global_variables())
            tf.global_variables_initializer().run()
        train_writer = tf.summary.FileWriter(args.tensorboard_dir, sess.graph)
        plt.ion()  # interactive mode so the attention plot updates in place
        plt.show()
        # Despite the name num_epoches, each iteration processes one batch.
        for b in range(args.num_epoches):
            # Pick the pre-built graph matching a random sequence length.
            seq_length = np.random.randint(1, args.max_seq_length + 1)
            model = model_list[seq_length - 1]
            x = generate_random_strings(args.batch_size, seq_length,
                                        args.vector_dim)
            feed_dict = {model.x: x}
            if b % 100 == 0:  # test
                p = 0  # select the p-th sample in the batch to show
                print(x[p, :, :])
                print(sess.run(model.o, feed_dict=feed_dict)[p, :, :])
                state_list = sess.run(model.state_list, feed_dict=feed_dict)
                if args.model == 'NTM':
                    w_plot = []
                    # Memory contents over time (collected but not plotted).
                    M_plot = np.concatenate(
                        [state['M'][p, :, :] for state in state_list])
                    # One row per timestep: read-head and write-head
                    # attention weights, concatenated side by side.
                    for state in state_list:
                        w_plot.append(
                            np.concatenate([
                                state['w_list'][0][p, :],
                                state['w_list'][1][p, :]
                            ]))
                    plt.imshow(w_plot, interpolation='nearest', cmap='gray')
                    plt.draw()
                    plt.pause(0.001)
                copy_loss = sess.run(model.copy_loss, feed_dict=feed_dict)
                merged_summary = sess.run(model.copy_loss_summary,
                                          feed_dict=feed_dict)
                train_writer.add_summary(merged_summary, b)
                print('batch %d, loss %g' % (b, copy_loss))
            else:  # train
                sess.run(model.train_op, feed_dict=feed_dict)
            if b % 5000 == 0 and b > 0:
                saver.save(sess,
                           args.save_dir + '/' + args.model + '/model.tfmodel',
                           global_step=b)
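Building one model per sequence length looks odd from a TF 2.x perspective, but it is the standard TF 1.x idiom: a static graph is unrolled for a fixed number of steps, so each length gets its own graph while reuse=True shares one set of weights between them. A minimal sketch of that sharing pattern (the scope and variable names here are hypothetical, not from the project):

import tensorflow as tf  # TF 1.x

def controller_kernel(reuse):
    # Every call against the same scope returns the same variable object.
    with tf.variable_scope('ntm_cell', reuse=reuse):
        return tf.get_variable('kernel', shape=[64, 64])

w_first = controller_kernel(reuse=False)  # first graph creates the weights
w_later = controller_kernel(reuse=True)   # later graphs reuse them
assert w_first is w_later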
Code example #3
File: copy_task_v2.py  Project: safai-labs/ntm
# CopyModel and generate_random_strings as in example #1.
def test(config, checkpoint_no):
    # Rebuild the model with the same hyperparameters used for training.
    model = CopyModel(batch_size=config['batch_size'],
                      vector_dim=config['vector_dim'],
                      model_type=config['model_type'],
                      cell_params=config['cell_params'][config['model_type']])

    # Restore the weights saved under save_dir at the given checkpoint number.
    model.load_weights(config['save_dir'] + '_' + str(checkpoint_no))

    x = generate_random_strings(config['batch_size'],
                                config['test_seq_length'],
                                config['vector_dim'])
    y_pred = model((x, config['test_seq_length']))
    print("original string sample: ", x[0])
    print("copied string sample: ", y_pred[0])
Code example #4
File: copy_task.py  Project: pukekaka/dogs
import numpy as np
import tensorflow as tf  # TF 1.x graph-mode API
import matplotlib.pyplot as plt

# NTMCopyModel and generate_random_strings are defined elsewhere in this project.


def test(args):
    # Build the graph at the fixed test length, then restore trained weights.
    model = NTMCopyModel(args, args.test_seq_length)
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(args.save_dir)
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)
        x = generate_random_strings(args.batch_size, args.test_seq_length,
                                    args.vector_dim)
        feed_dict = {model.x: x}
        output, copy_loss, state_list = sess.run(
            [model.o, model.copy_loss, model.state_list], feed_dict=feed_dict)
        # Print each input sequence next to the model's reconstruction.
        for p in range(args.batch_size):
            print(x[p, :, :])
            print(output[p, :, :])
        print('copy_loss: %g' % copy_loss)
        if args.model == 'NTM':
            # Plot read/write head attention over time for one sample; the
            # original relied on p leaking from the loop above, so make the
            # choice of the last sample explicit.
            p = args.batch_size - 1
            w_plot = []
            for state in state_list:
                w_plot.append(np.concatenate(
                    [state['w_list'][0][p, :], state['w_list'][1][p, :]]))
            plt.imshow(w_plot, interpolation='nearest', cmap='gray')
            plt.show()
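The args object read by examples #2 and #4 is presumably an argparse namespace. The sketch below only enumerates the attributes those two snippets actually access; the defaults are placeholders, not values from the project:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='NTM')  # 'NTM' enables attention plots
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--vector_dim', type=int, default=8)
parser.add_argument('--max_seq_length', type=int, default=20)
parser.add_argument('--test_seq_length', type=int, default=20)
parser.add_argument('--num_epoches', type=int, default=100000)  # spelling as used in train()
parser.add_argument('--restore_training', action='store_true')
parser.add_argument('--save_dir', default='save')
parser.add_argument('--tensorboard_dir', default='summary')
args = parser.parse_args()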