Exemplo n.º 1
0
def load_model(coach_path, model_path, args):
    """Load a coach/executor checkpoint pair and wrap them for evaluation.

    The coach class is picked from the checkpoint filename:
    'onehot' -> ConvOneHotCoach, 'gen' -> RnnGenerator,
    otherwise -> ConvRnnCoach.
    """
    if 'onehot' in coach_path:
        coach_cls = ConvOneHotCoach
    elif 'gen' in coach_path:
        coach_cls = RnnGenerator
    else:
        coach_cls = ConvRnnCoach
    coach = coach_cls.load(coach_path).to(device)
    coach.max_raw_chars = args.max_raw_chars

    executor = Executor.load(model_path).to(device)
    wrapper = ExecutorWrapper(
        coach,
        executor,
        coach.num_instructions,
        args.max_raw_chars,
        args.cheat,
    )
    wrapper.train(False)  # evaluation mode
    return wrapper
Exemplo n.º 2
0
def load_model(coach_path, executor_path):
    """Load a coach and an executor with fixed eval settings.

    Uses 200 max raw chars, cheat=0 and 'full' instruction mode.
    Coach class is chosen from the checkpoint filename.
    """
    if 'onehot' in coach_path:
        coach_cls = ConvOneHotCoach
    elif 'gen' in coach_path:
        coach_cls = RnnGenerator
    else:
        coach_cls = ConvRnnCoach
    coach = coach_cls.load(coach_path).to(device)
    coach.max_raw_chars = 200

    executor = Executor.load(executor_path).to(device)
    wrapper = ExecutorWrapper(
        coach,
        executor,
        coach.num_instructions,
        200,
        0,
        'full',
    )
    wrapper.train(False)  # evaluation mode
    return wrapper
Exemplo n.º 3
0
    def load_model(self, coach_path, executor_paths, args):
        """Load a coach (or reuse a shared coach object) plus a dict of
        executors, and return a MultiExecutorWrapper in eval mode.

        `coach_path` may be a checkpoint path (str) or an already-built
        coach object to share. `executor_paths` maps a key to an executor
        checkpoint path.
        """
        rule_emb_coach = getattr(args, "coach_rule_emb_size", 0)
        rule_emb_exec = getattr(args, "executor_rule_emb_size", 0)
        inst_dict_path = getattr(args, "inst_dict_path", None)
        random_init = getattr(args, "coach_random_init", False)

        assert isinstance(executor_paths, dict)

        if not isinstance(coach_path, str):
            # Already a coach object, shared across callers.
            print("Sharing coaches.")
            coach = coach_path
        elif "onehot" in coach_path:
            coach = ConvOneHotCoach.load(coach_path).to(self.device)
        elif "gen" in coach_path:
            coach = RnnGenerator.load(coach_path).to(self.device)
        else:
            coach = ConvRnnCoach.rl_load(
                coach_path,
                rule_emb_coach,
                inst_dict_path,
                coach_random_init=random_init,
            ).to(self.device)
        coach.max_raw_chars = args.max_raw_chars

        # One executor per entry in the checkpoint dict.
        executors = {
            name: Executor.rl_load(path, rule_emb_exec,
                                   inst_dict_path).to(self.device)
            for name, path in executor_paths.items()
        }

        wrapper = MultiExecutorWrapper(
            coach,
            executors,
            coach.num_instructions,
            args.max_raw_chars,
            args.cheat,
            args.inst_mode,
        )
        wrapper.train(False)  # evaluation mode
        return wrapper
Exemplo n.º 4
0
# Copyright (c) Facebook, Inc. and its affiliates.
Exemplo n.º 5
0
def main():
    """Train a coach model: parse args, build datasets, run train/eval epochs.

    Saves a checkpoint every epoch and a best checkpoint on each new best
    validation NLL; stops early after 2 consecutive non-improving epochs.
    """
    torch.backends.cudnn.benchmark = True

    # Two argument groups: 'main' (training options) and 'coach' (model args).
    parser = common_utils.Parser()
    parser.add_parser('main', get_main_parser())
    parser.add_parser('coach', ConvRnnCoach.get_arg_parser())

    args = parser.parse()
    parser.log()

    options = args['main']

    # Mirror stdout into a log file (skipped in dev mode).
    if not os.path.exists(options.model_folder):
        os.makedirs(options.model_folder)
    logger_path = os.path.join(options.model_folder, 'train.log')
    if not options.dev:
        sys.stdout = common_utils.Logger(logger_path)

    # Dev mode substitutes the dev split for both train and val datasets.
    if options.dev:
        options.train_dataset = options.train_dataset.replace('train.', 'dev.')
        options.val_dataset = options.val_dataset.replace('val.', 'dev.')

    print('Args:\n%s\n' % pprint.pformat(vars(options)))

    # gpu < 0 selects CPU; otherwise the given CUDA device index.
    if options.gpu < 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.gpu)

    common_utils.set_all_seeds(options.seed)

    # Instantiate the coach variant selected on the command line.
    model_args = args['coach']
    if options.coach_type == 'onehot':
        model = ConvOneHotCoach(model_args, 0, options.max_instruction_span,
                                options.num_resource_bin).to(device)
    elif options.coach_type in ['rnn', 'bow']:
        model = ConvRnnCoach(model_args, 0, options.max_instruction_span,
                             options.coach_type,
                             options.num_resource_bin).to(device)
    elif options.coach_type == 'rnn_gen':
        model = RnnGenerator(model_args, 0, options.max_instruction_span,
                             options.num_resource_bin).to(device)
    # NOTE(review): an unrecognized coach_type falls through every branch and
    # leaves `model` unbound -- the next line would raise NameError. Confirm
    # the arg parser restricts coach_type to these choices.

    print(model)

    train_dataset = CoachDataset(
        options.train_dataset,
        options.moving_avg_decay,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        model.inst_dict,
        options.max_instruction_span,
    )
    val_dataset = CoachDataset(
        options.val_dataset,
        options.moving_avg_decay,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        model.inst_dict,
        options.max_instruction_span,
    )
    # eval dataset reuses the val split but limits instructions to the
    # model's positive-instruction vocabulary.
    eval_dataset = CoachDataset(options.val_dataset,
                                options.moving_avg_decay,
                                options.num_resource_bin,
                                options.resource_bin_size,
                                options.max_num_prev_cmds,
                                model.inst_dict,
                                options.max_instruction_span,
                                num_instructions=model.args.num_pos_inst)

    # Precompute dataset caches (skipped in dev mode for faster iteration).
    if not options.dev:
        compute_cache(train_dataset)
        compute_cache(val_dataset)
        compute_cache(eval_dataset)

    if options.optim == 'adamax':
        optimizer = torch.optim.Adamax(model.parameters(),
                                       lr=options.lr,
                                       betas=(options.beta1, options.beta2))
    elif options.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=options.lr,
                                     betas=(options.beta1, options.beta2))
    else:
        assert False, 'not supported'

    # Pin host memory only when actually training on a GPU.
    train_loader = DataLoader(
        train_dataset,
        options.batch_size,
        shuffle=True,
        num_workers=1,  # if options.dev else 10,
        pin_memory=(options.gpu >= 0))
    val_loader = DataLoader(
        val_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=1,  # if options.dev else 10,
        pin_memory=(options.gpu >= 0))
    eval_loader = DataLoader(
        eval_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=1,  #0 if options.dev else 10,
        pin_memory=(options.gpu >= 0))

    best_val_nll = float('inf')
    overfit_count = 0  # consecutive epochs without val improvement
    for epoch in range(1, options.epochs + 1):
        print('==========')
        train(model, device, optimizer, options.grad_clip, train_loader, epoch)
        with torch.no_grad(), common_utils.EvalMode(model):
            val_nll = evaluate(model, device, val_loader, epoch, 'val', False)
            eval_nll = evaluate(model, device, eval_loader, epoch, 'eval',
                                True)

        # Checkpoint every epoch regardless of validation result.
        model_file = os.path.join(options.model_folder,
                                  'checkpoint%d.pt' % epoch)
        print('saving model to', model_file)
        model.save(model_file)

        if val_nll < best_val_nll:
            print('!!!New Best Model')
            overfit_count = 0
            best_val_nll = val_nll
            best_model_file = os.path.join(options.model_folder,
                                           'best_checkpoint.pt')
            print('saving best model to', best_model_file)
            model.save(best_model_file)
        else:
            # Early stop after 2 consecutive epochs without improvement.
            overfit_count += 1
            if overfit_count == 2:
                break

    print('train DONE')
0
if __name__ == '__main__':
    # Script entry point: load a coach + executor pair and start a game
    # context that streams actions over a data channel.
    args = parse_args()
    print('args:')
    pprint.pprint(vars(args))

    # Make the bundled Lua modules resolvable by the game engine.
    os.environ['LUA_PATH'] = os.path.join(args.lua_files, '?.lua')
    print('lua path:', os.environ['LUA_PATH'])

    # Mirror stdout into a log file under the save directory.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    logger_path = os.path.join(args.save_dir, 'train.log')
    sys.stdout = Logger(logger_path)

    # NOTE(review): assumes a CUDA device is available; args.gpu < 0 (CPU)
    # is not handled here.
    device = torch.device('cuda:%d' % args.gpu)
    coach = ConvRnnCoach.load(args.coach_path).to(device)
    coach.max_raw_chars = args.max_raw_chars
    executor = Executor.load(args.model_path).to(device)
    executor_wrapper = ExecutorWrapper(coach, executor, coach.num_instructions,
                                       args.max_raw_chars, args.cheat,
                                       args.inst_mode)
    executor_wrapper.train(False)  # evaluation mode

    game_option = get_game_option(args)
    ai1_option, ai2_option = get_ai_options(args, coach.num_instructions)

    # Spin up the game threads and expose their action channel.
    context, act_dc = create_game(args.num_thread, ai1_option, ai2_option,
                                  game_option)
    context.start()
    dc = DataChannelManager([act_dc])
Exemplo n.º 7
0
# Copyright (c) Facebook, Inc. and its affiliates.
Exemplo n.º 8
0
# Copyright (c) Facebook, Inc. and its affiliates.
Exemplo n.º 9
0
# Copyright (c) Facebook, Inc. and its affiliates.