def __init__(self, max_expressions=3, max_characters=50, num_examples=4, EOS=0, curriculum=True):
    """Set up sampling limits, token tables, and per-trajectory state.

    max_expressions / max_characters / num_examples bound program sampling;
    EOS is the end-of-sequence token id; curriculum toggles curriculum learning.
    """
    # Program-sampling configuration.
    self.max_expressions = max_expressions
    self.max_characters = max_characters
    self.num_examples = num_examples
    self.token_tables = op.build_token_tables()

    # Special token ids: SOS sits just past the last real op token.
    self.SOS = len(self.token_tables.op_token_table)
    self.EOS = EOS

    # Per-trajectory state, reset as episodes progress.
    self.reference_prog = None
    self.examples = None
    self.user_prog = []

    # Curriculum-learning bookkeeping.
    self.correct = 0
    self.curriculum = curriculum  # whether curriculum learning is enabled

    # Action/observation/reward space descriptors.
    self.num_actions = self.SOS
    self.observation_space = None
    self.reward_space = [0, 1]
def train_supervised(args):
    '''
    Parse arguments and build objects for supervised training approach.

    Constructs the RobustFill model (optionally restored from a checkpoint),
    an SGD or Adam optimizer, and a shuffling DataLoader over randomly
    sampled programs, then hands everything to train_supervised_.
    '''
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    token_tables = op.build_token_tables()
    from os import path
    train_logger = None
    if args.log_dir is not None:
        train_logger = tb.SummaryWriter(path.join(args.log_dir, 'train'),
                                        flush_secs=1)
    # init model
    robust_fill = RobustFill(
        string_size=len(op.CHARACTER),
        string_embedding_size=args.embedding_size,
        decoder_inp_size=128,
        hidden_size=args.hidden_size,
        program_size=len(token_tables.op_token_table),
    )
    if args.continue_training:
        # map_location keeps a GPU-saved checkpoint loadable on a CPU-only
        # host (consistent with the other loaders in this file).
        robust_fill.load_state_dict(
            torch.load(
                path.join(path.dirname(path.abspath(__file__)),
                          args.checkpoint_filename),
                map_location=device))
    robust_fill = robust_fill.to(device)
    robust_fill.set_device(device)

    if (args.optimizer == 'sgd'):
        optimizer = optim.SGD(robust_fill.parameters(), lr=args.lr)
    else:
        optimizer = optim.Adam(robust_fill.parameters(), lr=args.lr)

    train_dataset = RobustFillDataset(token_tables, d=args.number_progs)
    prog_dataloader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 collate_fn=my_collate,
                                 num_workers=4)
    train_supervised_(
        args,
        robust_fill=robust_fill,
        optimizer=optimizer,
        dataloader=prog_dataloader,
        train_logger=train_logger,
        checkpoint_filename=args.checkpoint_filename,
        checkpoint_step_size=args.checkpoint_step_size,
        checkpoint_print_tensors=args.print_tensors,
    )
def run_eval(args):
    '''
    Build the model and token tables from parsed args, optionally restore a
    checkpoint (mapped onto CPU), and invoke eval over 1000 samples.
    '''
    from os import path

    token_tables = op.build_token_tables()
    model = RobustFill(
        string_size=len(op.CHARACTER),
        string_embedding_size=args.embedding_size,
        decoder_inp_size=args.embedding_size,
        hidden_size=args.hidden_size,
        program_size=len(token_tables.op_token_table),
    )
    if args.continue_training:
        checkpoint_path = path.join(path.dirname(path.abspath(__file__)),
                                    args.checkpoint_filename)
        # Force CPU tensors so a GPU-trained checkpoint loads anywhere.
        state = torch.load(checkpoint_path,
                           map_location=torch.device('cpu'))
        model.load_state_dict(state)
    eval(model, token_tables, num_samples=1000, beam_size=args.beam_size,
         em=(not args.consistency))
def train_sac(args):
    '''
    Parse arguments and construct objects for training the SAC model:
    policy network, twin soft-Q networks with frozen target copies,
    per-network optimizers, replay buffer, HER, and the environment.
    '''
    from os import path

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    token_tables = op.build_token_tables()
    num_programs = len(token_tables.op_token_table)

    # Tensorboard logging (optional).
    train_logger = None
    if args.log_dir is not None:
        train_logger = tb.SummaryWriter(path.join(args.log_dir, 'train'),
                                        flush_secs=1)

    # Build policy and the two soft-Q networks plus their eval-mode targets.
    policy = RobustFill(string_size=len(op.CHARACTER),
                        string_embedding_size=args.embedding_size,
                        decoder_inp_size=128,
                        hidden_size=args.hidden_size,
                        program_size=num_programs,
                        device=device)
    q_1 = SoftQNetwork(128, num_programs, args.hidden_size)
    q_2 = SoftQNetwork(128, num_programs, args.hidden_size)
    tgt_q_1 = SoftQNetwork(128, num_programs, args.hidden_size).eval()
    tgt_q_2 = SoftQNetwork(128, num_programs, args.hidden_size).eval()

    # Restore checkpoints: the policy alone, or policy plus both Q nets.
    checkpoint_dir = path.dirname(path.abspath(__file__))
    if args.continue_training_policy or args.continue_training:
        policy.load_state_dict(
            torch.load(path.join(checkpoint_dir, args.checkpoint_filename),
                       map_location=device))
    if args.continue_training and not args.continue_training_policy:
        q_1.load_state_dict(
            torch.load(path.join(checkpoint_dir, args.q1_checkpoint_filename),
                       map_location=device))
        q_2.load_state_dict(
            torch.load(path.join(checkpoint_dir, args.q2_checkpoint_filename),
                       map_location=device))

    # Hard-copy Q weights into the targets and freeze them.
    for target_net, source_net in ((tgt_q_1, q_1), (tgt_q_2, q_2)):
        for tgt_p, src_p in zip(target_net.parameters(),
                                source_net.parameters()):
            tgt_p.data.copy_(src_p.data)
        for tgt_p in target_net.parameters():
            tgt_p.requires_grad = False

    policy = policy.to(device)
    policy.set_device(device)
    q_1 = q_1.to(device)
    q_2 = q_2.to(device)
    tgt_q_1 = tgt_q_1.to(device)
    tgt_q_2 = tgt_q_2.to(device)

    # One optimizer per trainable component (plus the entropy temperature).
    if (args.optimizer == 'sgd'):
        policy_opt = optim.SGD(policy.parameters(), lr=args.lr)
        q_1_opt = optim.SGD(q_1.parameters(), lr=args.lr)
        q_2_opt = optim.SGD(q_2.parameters(), lr=args.lr)
        entropy_opt = optim.SGD([policy.log_alpha], lr=args.lr)
    else:
        policy_opt = optim.Adam(policy.parameters(), lr=args.lr)
        q_1_opt = optim.Adam(q_1.parameters(), lr=args.lr)
        q_2_opt = optim.Adam(q_2.parameters(), lr=args.lr)
        entropy_opt = optim.Adam([policy.log_alpha], lr=args.lr)

    # Environment, replay buffer, and hindsight experience replay.
    env = RobustFillEnv()
    replay_buffer_size = 1_000_000
    replay_buffer = Replay_Buffer(replay_buffer_size, args.batch_size)
    her = HER()

    train_sac_(args, policy, q_1, q_2, tgt_q_1, tgt_q_2, policy_opt, q_1_opt,
               q_2_opt, entropy_opt, replay_buffer, her, env, train_logger,
               args.checkpoint_filename, args.checkpoint_step_size,
               args.print_tensors)
def train_reinforce(args):
    '''
    Parse arguments and construct objects for training the REINFORCE model
    (no baseline): policy and value networks, their optimizers, and the
    environment, then delegate to train_reinforce_.
    '''
    from os import path

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    token_tables = op.build_token_tables()

    # Tensorboard logging (optional).
    train_logger = None
    if args.log_dir is not None:
        train_logger = tb.SummaryWriter(path.join(args.log_dir, 'train'),
                                        flush_secs=1)

    # Policy and value networks.
    policy = RobustFill(string_size=len(op.CHARACTER),
                        string_embedding_size=args.embedding_size,
                        decoder_inp_size=args.embedding_size,
                        hidden_size=args.hidden_size,
                        program_size=len(token_tables.op_token_table),
                        device=device)
    value = ValueNetwork(args.embedding_size, args.hidden_size).to(device)

    # Restore checkpoints: the policy alone, or policy plus value net.
    checkpoint_dir = path.dirname(path.abspath(__file__))
    if args.continue_training_policy or args.continue_training:
        policy.load_state_dict(
            torch.load(path.join(checkpoint_dir, args.checkpoint_filename),
                       map_location=device))
    if args.continue_training and not args.continue_training_policy:
        value.load_state_dict(
            torch.load(path.join(checkpoint_dir,
                                 args.val_checkpoint_filename),
                       map_location=device))

    policy = policy.to(device)
    value = value.to(device)

    # One optimizer per network.
    if (args.optimizer == 'sgd'):
        pol_opt = optim.SGD(policy.parameters(), lr=args.lr)
        val_opt = optim.SGD(value.parameters(), lr=args.lr)
    else:
        pol_opt = optim.Adam(policy.parameters(), lr=args.lr)
        val_opt = optim.Adam(value.parameters(), lr=args.lr)

    # Load Environment
    env = RobustFillEnv()
    train_reinforce_(
        args,
        policy=policy,
        value=value,
        pol_opt=pol_opt,
        value_opt=val_opt,
        env=env,
        train_logger=train_logger,
        checkpoint_filename=args.checkpoint_filename,
        checkpoint_step_size=args.checkpoint_step_size,
        checkpoint_print_tensors=args.print_tensors,
    )
def _pad_programs(programs, max_length, pad_value, device):
    '''Right-pad each token sequence in `programs` to `max_length` with
    `pad_value` and return the result as a LongTensor on `device`.'''
    return torch.LongTensor([
        [program[i] if i < len(program) else pad_value
         for i in range(max_length)]
        for program in programs
    ]).to(device)


def train_supervised_(args, robust_fill, optimizer, dataloader, train_logger,
                      checkpoint_filename, checkpoint_step_size,
                      checkpoint_print_tensors):
    '''
    Classic training loop for the supervised algorithm.

    Runs forever over `dataloader` batches, training `robust_fill` with
    teacher forcing and cross-entropy loss (padding positions ignored).
    Logs the loss to `train_logger` when given, and every
    `checkpoint_step_size` iterations prints progress, optionally decodes
    one example program, and saves weights to `checkpoint_filename`.
    '''
    token_tables = op.build_token_tables()
    device = robust_fill.device
    global_iter = 0
    # No number of iterations here - just train for a real long time.
    while True:
        for b in dataloader:
            optimizer.zero_grad()
            expected_programs, examples = b
            max_length = max_program_length(expected_programs)

            # Teacher forcing: decoder input is the target program,
            # padded with token 0.
            padded_tgt_programs = _pad_programs(expected_programs,
                                                max_length, 0, device)

            # Output: program_size x b x #ops, need to turn b x #ops x #p_size
            actual_programs = robust_fill(examples,
                                          padded_tgt_programs).permute(
                                              1, 2, 0)

            # The loss target pads with -1 so cross_entropy can skip the
            # padded positions via ignore_index.
            padding_index = -1
            padded_expected_programs = _pad_programs(expected_programs,
                                                     max_length,
                                                     padding_index, device)

            loss = F.cross_entropy(actual_programs,
                                   padded_expected_programs,
                                   ignore_index=padding_index)
            loss.backward()
            if args.grad_clip > 0.:
                torch.nn.utils.clip_grad_norm_(robust_fill.parameters(),
                                               args.grad_clip)
            optimizer.step()

            # Debugging information
            if train_logger is not None:
                train_logger.add_scalar('loss', loss.item(), global_iter)

            if global_iter % checkpoint_step_size == 0:
                print('Checkpointing at batch {}'.format(global_iter))
                print('Loss: {}'.format(loss.item()))

                # note this code will not print correct if more than 1 printed
                if checkpoint_print_tensors:
                    temp = actual_programs.permute(
                        2, 0, 1)[:len(expected_programs[0]), :1, :]
                    tokens = torch.argmax(temp.permute(1, 0, 2), dim=-1)
                    tokens = tokens[0].tolist()
                    print_programs(expected_programs[0], tokens,
                                   train_logger,
                                   token_tables.token_op_table, global_iter)

                if checkpoint_filename is not None:
                    print('Saving to file {}'.format(checkpoint_filename))
                    torch.save(robust_fill.state_dict(), checkpoint_filename)
                print('Done checkpointing model')

            global_iter += 1