def train_loop(args, train_loader, val_loader):
    """Run the main training loop.

    Trains the program generator (and/or execution engine, depending on
    ``args.model_type``) for ``args.num_iterations`` optimizer steps,
    saving a ``tf.train.Checkpoint`` on even-numbered epochs.

    Args:
        args: parsed command-line namespace; fields read here include
            ``vocab_json``, ``model_type``, ``learning_rate`` and
            ``num_iterations``.
        train_loader: indexable sequence of training samples (``len()``
            is taken and it is batched via ``batch_creater``).
        val_loader: sequence of validation samples; only its length is
            reported here (accuracy checks were commented out upstream).

    NOTE(review): the 'EE' and 'PG+EE' training branches, the periodic
    train/val accuracy checks and the best-state checkpoint dict were
    all commented out in the original source and remain unimplemented;
    only the 'PG' path is live.
    """
    vocab = utils.load_vocab(args.vocab_json)

    # One (model, kwargs, optimizer) triple per component; only those
    # selected by args.model_type are instantiated below.
    program_generator, pg_kwargs, pg_optimizer = None, None, None
    execution_engine, ee_kwargs, ee_optimizer = None, None, None
    baseline_model, baseline_kwargs, baseline_optimizer = None, None, None
    baseline_type = None
    pg_best_state, ee_best_state, baseline_best_state = None, None, None
    # NOTE(review): the original also built a Keras ModelCheckpoint here,
    # but it was immediately shadowed by the tf.train.Checkpoint below and
    # never used, so it has been removed as dead code.

    # Set up model(s).
    if args.model_type in ('PG', 'PG+EE'):
        program_generator, pg_kwargs = get_program_generator(args)
        pg_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the program generator:')
        print(program_generator.summary())
    if args.model_type in ('EE', 'PG+EE'):
        execution_engine, ee_kwargs = get_execution_engine(args)
        ee_optimizer = optimizers.Adam(args.learning_rate)
        print('Here is the execution engine:')
        print(execution_engine)

    # Bookkeeping for losses/accuracies; mostly unused while the accuracy
    # checks above remain commented out, kept for interface stability.
    stats = {
        'train_losses': [], 'train_rewards': [], 'train_losses_ts': [],
        'train_accs': [], 'val_accs': [], 'val_accs_ts': [],
        'best_val_acc': -1, 'model_t': 0,
    }

    t, epoch, reward_moving_average = 0, 0, 0
    batch_size = 64  # TODO(review): consider honoring args.batch_size instead
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(optimizer=pg_optimizer,
                                     program_generator=program_generator)

    print('train_loader has %d samples' % len(train_loader))
    print('val_loader has %d samples' % len(val_loader))

    data_load = batch_creater(train_loader, batch_size, False)
    print("Data load length :", len(data_load))

    while t < args.num_iterations:
        total_loss = 0
        epoch += 1
        print('Starting epoch %d' % epoch)
        print("value of t :", t)
        for run_num, batch in enumerate(data_load):
            batch_loss = 0
            t += 1
            # batch layout: (questions, _, feats, answers, programs, _)
            # -- assumed from the original tuple unpacking; TODO confirm
            # against batch_creater.
            questions = to_tensor(batch[0])
            feats = to_tensor(batch[2])
            answers = to_tensor(batch[3])
            programs = to_tensor(batch[4])
            print("----------------")
            questions_var = tf.Variable(questions)
            feats_var = tf.Variable(feats)
            answers_var = tf.Variable(answers)
            programs_var = None
            if programs[0] is not None:
                programs_var = tf.Variable(programs)
            reward = None
            if args.model_type == 'PG':
                # Train program generator with ground-truth programs.
                # Only the forward pass needs to be recorded on the tape;
                # gradients are computed after the tape context closes.
                with tf.GradientTape() as tape:
                    batch_loss = program_generator(questions_var,
                                                   programs_var)
                total_loss += batch_loss
                variables = program_generator.variables
                gradients = tape.gradient(batch_loss, variables)
                # BUG FIX: the original called
                #     apply_gradients(zip(gradients), variables)
                # which passes 1-tuples as grads_and_vars and the variable
                # list as the `name` argument. Each gradient must be paired
                # with its variable.
                pg_optimizer.apply_gradients(zip(gradients, variables))
                print('Epoch {} Batch No. {} Loss {:.4f}'.format(
                    epoch, run_num, batch_loss.numpy()))
            # Periodic checkpointing; placement inside the batch loop is
            # preserved from the original — presumably it was meant to run
            # once per even epoch. TODO(review): confirm intent.
            if epoch % 2 == 0:
                checkpoint.save(file_prefix=checkpoint_prefix)
            if t == args.num_iterations:
                break
        if t == args.num_iterations:
            break