Example #1
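
# Assumed imports for this snippet (not shown in the original excerpt);
# `utils`, `to_tensor`, `batch_creater`, `get_program_generator` and
# `get_execution_engine` are project-specific helpers defined elsewhere.
import os

import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ModelCheckpoint
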
def train_loop(args, train_loader, val_loader):
    vocab = utils.load_vocab(args.vocab_json)
    program_generator, pg_kwargs, pg_optimizer = None, None, None
    execution_engine, ee_kwargs, ee_optimizer = None, None, None
    baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
    baseline_type = None

    pg_best_state, ee_best_state, baseline_best_state = None, None, None

    # Keras callback for tracking the best weights. The live code below uses a
    # tf.train.Checkpoint instead, so this callback only takes effect if it is
    # passed to a Keras fit() call; renamed to avoid shadowing `checkpoint`.
    keras_checkpoint = ModelCheckpoint(args.checkpoint_path,
                                       monitor='loss',
                                       verbose=1,
                                       save_best_only=True,
                                       mode='min',
                                       load_weights_on_restart=True)
    # Set up model
    if args.model_type == 'PG' or args.model_type == 'PG+EE':
        program_generator, pg_kwargs = get_program_generator(args)
        pg_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the program generator:')
        print(program_generator)
        # program_generator.build(input_shape=[46,])
        # program_generator.compile(optimizer='adam', loss='mse')
        # print(program_generator.summary())
    if args.model_type == 'EE' or args.model_type == 'PG+EE':
        execution_engine, ee_kwargs = get_execution_engine(args)
        ee_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the execution engine:')
        print(execution_engine)
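    # get_program_generator / get_execution_engine are assumed to return a
    # (model, constructor_kwargs) pair; the kwargs are kept so they can be
    # written into checkpoints alongside the weights (see the commented
    # checkpoint dict near the end of this function).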

    stats = {
        'train_losses': [],
        'train_rewards': [],
        'train_losses_ts': [],
        'train_accs': [],
        'val_accs': [],
        'val_accs_ts': [],
        'best_val_acc': -1,
        'model_t': 0,
    }
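    # In this excerpt, `stats` is only initialised; the commented
    # accuracy/checkpoint block near the end of the function is where these
    # entries would be appended and saved.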
    t, epoch, reward_moving_average = 0, 0, 0
    batch_size = 64  # NOTE: hard-coded; args.batch_size is not used in this loop
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=pg_optimizer,
                                     program_generator=program_generator)
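    # The object-based checkpoint above can be restored in a later session; a
    # minimal sketch, assuming the optimizer and program generator have been
    # rebuilt first:
    #
    #     status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
    #     status.assert_existing_objects_matched()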
    # set_mode('train', [program_generator, execution_engine, baseline_model])

    print('train_loader has %d samples' % len(train_loader))
    # train_loader = train_loader[:256]
    print('val_loader has %d samples' % len(val_loader))
    # data_sampler = iter(range(len(train_loader)))
    data_load = batch_creater(train_loader, batch_size, False)
    print("Data load length :", len(data_load))
    # print(data_load[0][0])
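    # batch_creater is assumed to chunk the training samples into batches of
    # size `batch_size` (shuffling disabled here); judging from the unpacking
    # below, each batch is a 6-tuple of (questions, _, feats, answers,
    # programs, _), with the underscored fields unused in this loop.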

    while t < args.num_iterations:
        total_loss = 0
        epoch += 1
        print('Starting epoch %d' % epoch)
        print("value of t :", t)
        # train_loader_data = get_data(train_loader)
        # print("train data loader length :", len(train_loader_data))
        # print(train_loader[0].shape)
        # print(train_loader[0])
        for run_num, batch in enumerate(data_load):
            batch_loss = 0
            with tf.GradientTape() as tape:
                t += 1
                questions, _, feats, answers, programs, _ = (
                    to_tensor(batch[0]), batch[1], to_tensor(batch[2]),
                    to_tensor(batch[3]), to_tensor(batch[4]), batch[5])

                #print("Questions : ", questions.shape)
                #print("Features :", feats.shape)
                #print(" Answers : ", answers.shape)
                #print(" prgrams : ", programs.shape)
                print("----------------")

                # Wrap the batch tensors; programs may be absent, so guard it.
                questions_var = tf.Variable(questions)
                feats_var = tf.Variable(feats)
                answers_var = tf.Variable(answers)
                programs_var = None
                if programs[0] is not None:
                    programs_var = tf.Variable(programs)

                reward = None
                if args.model_type == 'PG':
                    # Train the program generator with ground-truth programs.
                    batch_loss = program_generator(questions_var, programs_var)
            # Only the PG path computes a loss in this trimmed-down loop.
            total_loss += batch_loss
            variables = program_generator.trainable_variables
            gradients = tape.gradient(batch_loss, variables)
            pg_optimizer.apply_gradients(zip(gradients, variables))

            print('Epoch {} Batch No. {} Loss {:.4f}'.format(
                epoch, run_num, batch_loss.numpy()))
        if epoch % 2 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
        if t == args.num_iterations:
            break
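            # NOTE: the commented-out block below preserves the original
            # EE / PG+EE training paths and the periodic accuracy/checkpoint
            # logging; much of it still uses PyTorch-style calls
            # (zero_grad/backward/step, torch.save) and would need to be
            # ported to TensorFlow before being re-enabled.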
            # program_generator.compile(optimizer=pg_optimizer, loss=loss)
            # ques = np.asarray(questions_var.read_value())
            # prog = np.asarray(programs_var.read_value())
            # history = program_generator.fit(
            #     x=ques,
            #     y=prog,
            #     batch_size=args.batch_size,
            #     epochs=10,
            #     verbose=0,
            #     callbacks=[LossAndErrorPrintingCallback(), checkpoint])

            # elif args.model_type == 'EE':
            #     # Train execution engine with ground-truth programs
            #     scores = execution_engine(feats_var, programs_var)
            #     loss = tf.nn.softmax_cross_entropy_with_logits(
            #         scores, answers_var)
            #     execution_engine.compile(optimizer=ee_optimizer, loss=loss)
            #     history = execution_engine.fit(
            #         questions_var,
            #         to_categorical(answers_var),
            #         batch_size=args.batch_size,
            #         epochs=10,
            #         verbose=0,
            #         callbacks=[LossAndErrorPrintingCallback(), checkpoint])

            # elif args.model_type == 'PG+EE':
            #     programs_pred = program_generator.reinforce_sample(questions_var)
            #     scores = execution_engine(feats_var, programs_pred)
            #
            #     loss = tf.nn.softmax_cross_entropy_with_logits(scores, answers_var)
            #     _, preds = scores.data.max(1)
            #     # raw_reward = (preds == answers).float()
            #     raw_reward = tf.cast((preds == answers), dtype=tf.float32)
            #     reward_moving_average *= args.reward_decay
            #     reward_moving_average += (1.0 - args.reward_decay) * raw_reward.mean()
            #     centered_reward = raw_reward - reward_moving_average
            #
            #     if args.train_execution_engine == 1:
            #         ee_optimizer.zero_grad()
            #         loss.backward()
            #         ee_optimizer.step()
            #
            #     if args.train_program_generator == 1:
            #         pg_optimizer.zero_grad()
            #         program_generator.reinforce_backward(centered_reward.cuda())
            #         pg_optimizer.step()

            # if t % args.record_loss_every == 0:
            #     print(t, loss.data[0])
            #     stats['train_losses'].append(loss.data[0])
            #     stats['train_losses_ts'].append(t)
            #     if reward is not None:
            #         stats['train_rewards'].append(reward)
            #
            # if t % args.checkpoint_every == 0:
            #     print('Checking training accuracy ... ')
            #     train_acc = check_accuracy(args, program_generator, execution_engine,
            #                                baseline_model, train_loader)
            #     print('train accuracy is', train_acc)
            #     print('Checking validation accuracy ...')
            #     val_acc = check_accuracy(args, program_generator, execution_engine,
            #                              baseline_model, val_loader)
            #     print('val accuracy is ', val_acc)
            #     stats['train_accs'].append(train_acc)
            #     stats['val_accs'].append(val_acc)
            #     stats['val_accs_ts'].append(t)
            #
            #     if val_acc > stats['best_val_acc']:
            #         stats['best_val_acc'] = val_acc
            #         stats['model_t'] = t
            #         best_pg_state = get_state(program_generator)
            #         best_ee_state = get_state(execution_engine)
            #         best_baseline_state = get_state(baseline_model)
            #
            #     checkpoint = {
            #         'args': args.__dict__,
            #         'program_generator_kwargs': pg_kwargs,
            #         'program_generator_state': best_pg_state,
            #         'execution_engine_kwargs': ee_kwargs,
            #         'execution_engine_state': best_ee_state,
            #         'baseline_kwargs': baseline_kwargs,
            #         'baseline_state': best_baseline_state,
            #         'baseline_type': baseline_type,
            #         'vocab': vocab
            #     }
            #     for k, v in stats.items():
            #         checkpoint[k] = v
            #     print('Saving checkpoint to %s' % args.checkpoint_path)
            #     torch.save(checkpoint, args.checkpoint_path)
            #     del checkpoint['program_generator_state']
            #     del checkpoint['execution_engine_state']
            #     del checkpoint['baseline_state']
            #     with open(args.checkpoint_path + '.json', 'w') as f:
            #         json.dump(checkpoint, f)

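
# --- Hypothetical usage sketch (not in the original) --------------------------
# Assumes an argparse-style `args` object exposing the fields referenced in
# train_loop (vocab_json, model_type, learning_rate, num_iterations,
# checkpoint_path, ...) and project-specific data loaders; the names below are
# illustrative only.
#
# if __name__ == '__main__':
#     args = parse_args()                                # hypothetical arg parser
#     train_loader = load_dataset(args, split='train')   # hypothetical loader
#     val_loader = load_dataset(args, split='val')       # hypothetical loader
#     train_loop(args, train_loader, val_loader)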