示例#1
0
    def __init__(self,
                 max_expressions=3,
                 max_characters=50,
                 num_examples=4,
                 EOS=0,
                 curriculum=True):
        """Initialize sampling limits, token tables, and per-episode state."""
        # Limits used when sampling reference programs and example strings.
        self.max_expressions = max_expressions
        self.max_characters = max_characters
        self.num_examples = num_examples
        self.token_tables = op.build_token_tables()

        # Special tokens: SOS index sits just past the operator vocabulary.
        self.SOS = len(self.token_tables.op_token_table)
        self.EOS = EOS

        # State for the trajectory currently being rolled out.
        self.reference_prog = None
        self.examples = None
        self.user_prog = []

        # Curriculum-learning bookkeeping (correct-count drives progression).
        self.correct = 0
        self.curriculum = curriculum

        # Environment interface attributes.
        self.num_actions = self.SOS
        self.observation_space = None
        self.reward_space = [0, 1]
示例#2
0
def train_supervised(args):
    '''
    Parse arguments and build objects for supervised training approach.

    Constructs the RobustFill model (optionally restored from a
    checkpoint), an optimizer, the dataset/dataloader and an optional
    tensorboard logger, then delegates to train_supervised_ for the
    actual training loop.
    '''
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    token_tables = op.build_token_tables()
    from os import path
    train_logger = None
    if args.log_dir is not None:
        train_logger = tb.SummaryWriter(path.join(args.log_dir, 'train'),
                                        flush_secs=1)

    # init model
    # NOTE(review): decoder_inp_size is hard-coded to 128 here, while
    # run_eval/train_reinforce use args.embedding_size — confirm this is
    # intentional; changing it would invalidate existing checkpoints.
    robust_fill = RobustFill(
        string_size=len(op.CHARACTER),
        string_embedding_size=args.embedding_size,
        decoder_inp_size=128,
        hidden_size=args.hidden_size,
        program_size=len(token_tables.op_token_table),
    )
    if args.continue_training:
        # Checkpoint is resolved relative to this source file's directory.
        robust_fill.load_state_dict(
            torch.load(
                path.join(path.dirname(path.abspath(__file__)),
                          args.checkpoint_filename)))
    robust_fill = robust_fill.to(device)
    robust_fill.set_device(device)

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(robust_fill.parameters(), lr=args.lr)
    else:
        optimizer = optim.Adam(robust_fill.parameters(), lr=args.lr)

    train_dataset = RobustFillDataset(token_tables, d=args.number_progs)
    # Fixed local-name typo: was 'prog_dataloaer'.
    prog_dataloader = DataLoader(dataset=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 collate_fn=my_collate,
                                 num_workers=4)

    train_supervised_(
        args,
        robust_fill=robust_fill,
        optimizer=optimizer,
        dataloader=prog_dataloader,
        train_logger=train_logger,
        checkpoint_filename=args.checkpoint_filename,
        checkpoint_step_size=args.checkpoint_step_size,
        checkpoint_print_tensors=args.print_tensors,
    )
示例#3
0
def run_eval(args):
    '''
    Constructs necessary data structures and parses args to call eval
    '''
    from os import path

    token_tables = op.build_token_tables()
    model = RobustFill(
        string_size=len(op.CHARACTER),
        string_embedding_size=args.embedding_size,
        decoder_inp_size=args.embedding_size,
        hidden_size=args.hidden_size,
        program_size=len(token_tables.op_token_table),
    )
    # Optionally restore weights; checkpoint path is relative to this file
    # and tensors are mapped onto CPU for evaluation.
    if args.continue_training:
        ckpt_path = path.join(path.dirname(path.abspath(__file__)),
                              args.checkpoint_filename)
        state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

    eval(model,
         token_tables,
         num_samples=1000,
         beam_size=args.beam_size,
         em=(not args.consistency))
示例#4
0
def train_sac(args):
    '''
    Parse arguments and construct objects for training sac model.

    Builds the policy, twin soft-Q networks and their frozen target
    copies, the optimizers, replay buffer, HER and environment, then
    delegates the actual training loop to train_sac_.
    '''
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    token_tables = op.build_token_tables()
    # initialize tensorboard for logging output
    from os import path
    train_logger = None
    if args.log_dir is not None:
        train_logger = tb.SummaryWriter(path.join(args.log_dir, 'train'),
                                        flush_secs=1)

    def _restore(module, filename):
        # Load a checkpoint stored next to this source file onto `device`.
        # Deduplicates the torch.load boilerplate repeated below.
        module.load_state_dict(
            torch.load(path.join(path.dirname(path.abspath(__file__)),
                                 filename),
                       map_location=device))

    # Load Models
    policy = RobustFill(string_size=len(op.CHARACTER),
                        string_embedding_size=args.embedding_size,
                        decoder_inp_size=128,
                        hidden_size=args.hidden_size,
                        program_size=len(token_tables.op_token_table),
                        device=device)
    q_1 = SoftQNetwork(128, len(token_tables.op_token_table), args.hidden_size)
    q_2 = SoftQNetwork(128, len(token_tables.op_token_table), args.hidden_size)

    # Target networks stay in eval mode and are never trained directly.
    tgt_q_1 = SoftQNetwork(128, len(token_tables.op_token_table),
                           args.hidden_size).eval()
    tgt_q_2 = SoftQNetwork(128, len(token_tables.op_token_table),
                           args.hidden_size).eval()

    if args.continue_training_policy:
        # Resume with a pretrained policy only (Q networks start fresh).
        _restore(policy, args.checkpoint_filename)
    elif args.continue_training:
        # Resume a full SAC run: policy plus both Q networks.
        _restore(policy, args.checkpoint_filename)
        _restore(q_1, args.q1_checkpoint_filename)
        _restore(q_2, args.q2_checkpoint_filename)

    # Initialize targets as exact copies of the online Q networks, then
    # freeze them so gradients never flow into the target parameters.
    for target_param, param in zip(tgt_q_1.parameters(), q_1.parameters()):
        target_param.data.copy_(param.data)
    for target_param, param in zip(tgt_q_2.parameters(), q_2.parameters()):
        target_param.data.copy_(param.data)
    for param in tgt_q_1.parameters():
        param.requires_grad = False
    for param in tgt_q_2.parameters():
        param.requires_grad = False

    policy = policy.to(device)
    policy.set_device(device)
    q_1 = q_1.to(device)
    q_2 = q_2.to(device)
    tgt_q_1 = tgt_q_1.to(device)
    tgt_q_2 = tgt_q_2.to(device)

    # Initialize optimizers
    if args.optimizer == 'sgd':
        policy_opt = optim.SGD(policy.parameters(), lr=args.lr)
        q_1_opt = optim.SGD(q_1.parameters(), lr=args.lr)
        q_2_opt = optim.SGD(q_2.parameters(), lr=args.lr)
        entropy_opt = optim.SGD([policy.log_alpha], lr=args.lr)
    else:
        policy_opt = optim.Adam(policy.parameters(), lr=args.lr)
        q_1_opt = optim.Adam(q_1.parameters(), lr=args.lr)
        q_2_opt = optim.Adam(q_2.parameters(), lr=args.lr)
        entropy_opt = optim.Adam([policy.log_alpha], lr=args.lr)

    # Other necessary objects
    env = RobustFillEnv()
    replay_buffer_size = 1_000_000
    replay_buffer = Replay_Buffer(replay_buffer_size, args.batch_size)
    her = HER()

    train_sac_(args, policy, q_1, q_2, tgt_q_1, tgt_q_2, policy_opt, q_1_opt,
               q_2_opt, entropy_opt, replay_buffer, her, env, train_logger,
               args.checkpoint_filename, args.checkpoint_step_size,
               args.print_tensors)
示例#5
0
def train_reinforce(args):
    '''
    Parse arguments and construct objects for training reinforce model,
    with no baseline.

    NOTE(review): a ValueNetwork is constructed and optimized below, so a
    learned baseline appears to be in play — confirm the "no baseline"
    claim against train_reinforce_.
    '''
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    token_tables = op.build_token_tables()

    # initialize tensorboard for logging output
    from os import path
    train_logger = None
    if args.log_dir is not None:
        train_logger = tb.SummaryWriter(path.join(args.log_dir, 'train'),
                                        flush_secs=1)

    def _restore(module, filename):
        # Load a checkpoint stored next to this source file onto `device`.
        # Deduplicates the torch.load boilerplate repeated below.
        module.load_state_dict(
            torch.load(path.join(path.dirname(path.abspath(__file__)),
                                 filename),
                       map_location=device))

    # Load Models
    policy = RobustFill(string_size=len(op.CHARACTER),
                        string_embedding_size=args.embedding_size,
                        decoder_inp_size=args.embedding_size,
                        hidden_size=args.hidden_size,
                        program_size=len(token_tables.op_token_table),
                        device=device)
    value = ValueNetwork(args.embedding_size, args.hidden_size).to(device)
    if args.continue_training_policy:
        # Resume with a pretrained policy only (value network starts fresh).
        _restore(policy, args.checkpoint_filename)
    elif args.continue_training:
        # Resume a full run: policy plus value network.
        _restore(policy, args.checkpoint_filename)
        _restore(value, args.val_checkpoint_filename)
    policy = policy.to(device)
    value = value.to(device)
    # Initialize Optimizer
    if args.optimizer == 'sgd':
        pol_opt = optim.SGD(policy.parameters(), lr=args.lr)
        val_opt = optim.SGD(value.parameters(), lr=args.lr)
    else:
        pol_opt = optim.Adam(policy.parameters(), lr=args.lr)
        val_opt = optim.Adam(value.parameters(), lr=args.lr)

    # Load Environment
    env = RobustFillEnv()
    train_reinforce_(
        args,
        policy=policy,
        value=value,
        pol_opt=pol_opt,
        value_opt=val_opt,
        env=env,
        train_logger=train_logger,
        checkpoint_filename=args.checkpoint_filename,
        checkpoint_step_size=args.checkpoint_step_size,
        checkpoint_print_tensors=args.print_tensors,
    )
示例#6
0
def train_supervised_(args, robust_fill, optimizer, dataloader, train_logger,
                      checkpoint_filename, checkpoint_step_size,
                      checkpoint_print_tensors):
    '''
    Classic training loop for supervised algorithm.

    Loops forever over the dataloader: teacher-forced forward pass,
    cross-entropy loss with padded positions ignored, optional gradient
    clipping, and periodic logging/checkpointing every
    checkpoint_step_size batches. Never returns.
    '''
    token_tables = op.build_token_tables()
    device = robust_fill.device
    global_iter = 0
    # No number of iterations here - just train for a real long time
    while True:
        for b in dataloader:

            optimizer.zero_grad()

            # b is (expected_programs, examples) as produced by the
            # dataloader's collate function.
            expected_programs, examples = b
            max_length = max_program_length(expected_programs)

            # teacher learning: decoder inputs are ground-truth tokens,
            # with shorter programs right-padded using token 0.
            padded_tgt_programs = torch.LongTensor([[
                program[i] if i < len(program) else 0
                for i in range(max_length)
            ] for program in expected_programs]).to(device)

            # Output: program_size x b x #ops, need to turn b x #ops x #p_size
            actual_programs = robust_fill(examples,
                                          padded_tgt_programs).permute(
                                              1, 2, 0)
            # Targets use -1 padding so those positions are skipped by
            # cross_entropy via ignore_index below.
            padding_index = -1
            padded_expected_programs = torch.LongTensor([[
                program[i] if i < len(program) else padding_index
                for i in range(max_length)
            ] for program in expected_programs]).to(device)

            loss = F.cross_entropy(actual_programs,
                                   padded_expected_programs,
                                   ignore_index=padding_index)
            loss.backward()
            # Optional gradient clipping (disabled when grad_clip <= 0).
            if args.grad_clip > 0.:
                torch.nn.utils.clip_grad_norm_(robust_fill.parameters(),
                                               args.grad_clip)
            optimizer.step()

            # Debugging information
            if train_logger is not None:
                train_logger.add_scalar('loss', loss.item(), global_iter)

            if global_iter % checkpoint_step_size == 0:
                print('Checkpointing at batch {}'.format(global_iter))
                print('Loss: {}'.format(loss.item()))

                # note this code will not print correct if more than 1 printed
                if checkpoint_print_tensors:
                    # Slice out the first batch element's predicted tokens
                    # (up to the first expected program's length) and decode.
                    temp = actual_programs.permute(
                        2, 0, 1)[:len(expected_programs[0]), :1, :]
                    tokens = torch.argmax(temp.permute(1, 0, 2), dim=-1)
                    tokens = tokens[0].tolist()
                    print_programs(expected_programs[0], tokens, train_logger,
                                   token_tables.token_op_table, global_iter)

                if checkpoint_filename is not None:
                    print('Saving to file {}'.format(checkpoint_filename))
                    torch.save(robust_fill.state_dict(), checkpoint_filename)
                print('Done checkpointing model')
            global_iter += 1