for name, param in model.named_parameters():
        print(name, param.shape, param.device, param.requires_grad)

    max_step = config['model']['max_step']
    check_step = config['model']['check_step']
    batch_size = config['model']['batch_size']
    model.zero_grad()
    train_slot_loss, train_intent_loss = 0, 0
    best_val_f1 = 0.

    writer.add_text('config', json.dumps(config))

    for step in range(1, max_step + 1):
        model.train()
        batched_data = dataloader.get_train_batch(batch_size)
        batched_data = tuple(t.to(DEVICE) for t in batched_data)
        word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = batched_data
        if not config['model']['context']:
            context_seq_tensor, context_mask_tensor = None, None
        _, _, slot_loss, intent_loss = model.forward(
            word_seq_tensor, word_mask_tensor, tag_seq_tensor, tag_mask_tensor,
            intent_tensor, context_seq_tensor, context_mask_tensor)
        train_slot_loss += slot_loss.item()
        train_intent_loss += intent_loss.item()
        loss = slot_loss + intent_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if config['model']['finetune']:
            scheduler.step()  # Update learning rate schedule
Пример #2
0
def _split_transitions(batch):
    """Return five lists holding fields 0-4 of every record in ``batch``.

    Presumably the fields are (state, action, next state, reward, terminal
    flag), judging by the variable names used by the caller -- confirm
    against ``Dataloader``.
    """
    return tuple([record[i] for record in batch] for i in range(5))


def train():
    """Train the DQN, periodically evaluating on validation data and saving.

    Calls ``dataloader.split(0.1)`` and keeps the result of
    ``dataloader.get_val()`` as the validation set. A ``DQN`` is built inside
    a fresh graph and session; its weights are restored from
    ``FLAGS.model_dir`` when a checkpoint is found there, otherwise all
    variables are initialised. Then ``FLAGS.epoch`` iterations are run, each
    on one batch of ``FLAGS.batch_size`` records.

    Every ``FLAGS.evaluate_every`` global steps the validation set is
    evaluated and one random validation sample is printed; every 1000 global
    steps a checkpoint is written to ``FLAGS.model_dir + '/test_model'``.

    Side effects: rebinds the module-level ``dqn`` and ``sess`` (the session
    is intentionally left open because it is published as a global).
    """
    # Local import so this block does not depend on how the file imports
    # ``randint`` (random.randint is inclusive, numpy's is exclusive).
    import random

    global dqn, sess

    dataloader = Dataloader()
    dataloader.split(0.1)
    (dev_st_batch, dev_at_batch, dev_st1_batch, dev_rt_batch,
     dev_terminal_batch) = _split_transitions(dataloader.get_val())

    with tf.Graph().as_default():
        sess = tf.Session(config=config)
        with sess.as_default():
            dqn = DQN(goal_size=dataloader.goal_size,
                      act_size=dataloader.action_size)

            # Resume from the latest checkpoint if there is one, otherwise
            # start from freshly initialised variables.
            ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
            if ckpt:
                print("Reading ckpt model from %s" %
                      ckpt.model_checkpoint_path)
                dqn.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("Create model with fresh parameters")
                sess.run(tf.global_variables_initializer())

            for _ in range(FLAGS.epoch):
                train_batch = dataloader.get_train_batch(FLAGS.batch_size)
                (st_batch, at_batch, st1_batch, rt_batch,
                 terminal_batch) = _split_transitions(train_batch)

                target_q_batch, _unused = max_q(st_batch, st1_batch)
                train_step(st_batch, at_batch, target_q_batch, rt_batch,
                           terminal_batch)
                current_step = tf.train.global_step(sess, dqn.global_step)

                if current_step % FLAGS.evaluate_every == 0:
                    # Pass the validation *states*, mirroring the training
                    # call above; the previous code passed the action batch
                    # as the first argument.
                    dev_target_q_batch, dev_q_action_batch = max_q(
                        dev_st_batch, dev_st1_batch)
                    dev_step(dev_st_batch, dev_at_batch, dev_target_q_batch,
                             dev_rt_batch, dev_terminal_batch)
                    # Show one random validation sample. randrange excludes
                    # the upper bound, so the index is always valid; skip
                    # entirely when the validation split is empty.
                    if dev_st_batch:
                        idx = random.randrange(len(dev_st_batch))
                        print("State: " +
                              str(dataloader.mapState(dev_st1_batch[idx])))
                        print("Action: " +
                              str(dataloader.mapAction(
                                  dev_q_action_batch[idx])))

                if current_step % 1000 == 0:
                    dqn.saver.save(sess,
                                   FLAGS.model_dir + '/test_model',
                                   global_step=dqn.global_step)
Пример #3
0
def train():
    """Run the training loop, checkpointing and testing after every epoch.

    Reads ``args.data_dir``, ``args.batch_size``, ``args.epochs`` and
    ``args.output_dir``. For each epoch it:

    1. re-orders the batches via ``dataloader.flesh_batch_order()`` and runs
       ``n_batches`` training steps, writing each step's summary to
       ``model.train_writer`` and printing a running loss every 10 batches;
    2. saves a checkpoint to ``os.path.join(args.output_dir, 'ckpt')``;
    3. runs the model on every test sample one at a time and writes the mean
       position error to ``model.test_writer``.

    Progress, losses and timings are printed to stdout.
    """
    # load data
    print("loading test data")
    dataloader = Dataloader(args.data_dir, args.batch_size)
    # The third returned value (angles) is not needed by the training loop.
    test_x, test_y, _ = dataloader.get_test_data()
    n_tests = test_x.shape[0]
    # NOTE(review): the training-set size is hard-coded; confirm it matches
    # the data found in args.data_dir.
    n_samples = 1800
    n_batches = n_samples // args.batch_size
    print('total number of samples: {}'.format(n_samples))
    print('total number of steps: {}'.format(n_batches * args.epochs))

    config = tf.ConfigProto(device_count={'GPU': 1}, allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        model = create_model(sess)

        log_every_n_batches = 10
        start_time = time.time()
        # ``range`` (not ``xrange``) so the loop also runs on Python 3 and
        # matches the test loop below.
        for e in range(args.epochs):
            print('working on epoch {0}/{1}'.format(e + 1, args.epochs))
            epoch_start_time = time.time()
            epoch_loss, batch_loss = 0, 0
            dataloader.flesh_batch_order()

            for i in range(n_batches):
                at_log_point = (i + 1) % log_every_n_batches == 0
                if at_log_point:
                    print('working on epoch {0}, batch {1}/{2}'.format(
                        e + 1, i + 1, n_batches))
                enc_in, dec_out = dataloader.get_train_batch(i)
                _, _, step_loss, summary = model.step(sess, enc_in, dec_out,
                                                      True)
                epoch_loss += step_loss
                batch_loss += step_loss
                model.train_writer.add_summary(summary,
                                               model.global_step.eval())
                if at_log_point:
                    # Mean loss over the batches since the last log line.
                    print('current batch loss: {:.2f}'.format(
                        batch_loss / log_every_n_batches))
                    batch_loss = 0

            epoch_time = time.time() - epoch_start_time
            print('epoch {0}/{1} finish in {2:.2f} s'.format(
                e + 1, args.epochs, epoch_time))
            # n_batches is 0 when args.batch_size exceeds n_samples; report
            # 0 instead of raising ZeroDivisionError.
            avg_epoch_loss = epoch_loss / n_batches if n_batches else 0.0
            print('average epoch loss: {:.4f}'.format(avg_epoch_loss))

            print('saving model...')
            model.saver.save(sess, os.path.join(args.output_dir, 'ckpt'),
                             model.global_step.eval())

            # test after each epoch
            if n_tests == 0:
                # Nothing to average over; skip rather than divide by zero.
                print('no test samples, skipping evaluation')
                continue

            loss, pos_err = 0, 0
            for j in range(n_tests):
                # Add a leading batch axis of size 1; the original comment
                # says the input must be [?, 32, 32, 500].
                enc_in = np.expand_dims(test_x[j], 0)
                dec_out = np.expand_dims(test_y[j], 0)
                pos, _, step_loss = model.step(sess, enc_in, dec_out, False)
                loss += step_loss
                pos_err += evaluate_pos(pos[0], test_y[j, ..., 100:400],
                                        test_y[j, ..., :100])
            avg_pos_err = pos_err / n_tests
            err_summary = sess.run(model.err_m_summary,
                                   {model.err_m: avg_pos_err})
            model.test_writer.add_summary(err_summary,
                                          model.global_step.eval())

            print('=================================\n'
                  'total loss avg:            %.4f\n'
                  'position error avg(m):     %.4f\n'
                  '=================================' %
                  (loss / n_tests, avg_pos_err))

        print('training finish in {:.2f} s'.format(time.time() - start_time))