def load_model(coach_path, model_path, args):
    """Load a coach/executor checkpoint pair and wrap them for evaluation.

    The coach architecture is inferred from the checkpoint filename:
    'onehot' -> ConvOneHotCoach, 'gen' -> RnnGenerator, otherwise
    ConvRnnCoach. Returns an ExecutorWrapper switched to eval mode.
    """
    # Pick the coach class from the filename convention, then load once.
    if 'onehot' in coach_path:
        coach_cls = ConvOneHotCoach
    elif 'gen' in coach_path:
        coach_cls = RnnGenerator
    else:
        coach_cls = ConvRnnCoach
    coach = coach_cls.load(coach_path).to(device)
    coach.max_raw_chars = args.max_raw_chars

    executor = Executor.load(model_path).to(device)

    wrapper = ExecutorWrapper(
        coach, executor, coach.num_instructions, args.max_raw_chars, args.cheat)
    wrapper.train(False)  # inference only
    return wrapper
def load_model(coach_path, executor_path, max_raw_chars=200, cheat=0,
               inst_mode='full'):
    """Load a coach/executor checkpoint pair and wrap them for evaluation.

    Args:
        coach_path: path to the coach checkpoint. The architecture is
            inferred from the filename: 'onehot' -> ConvOneHotCoach,
            'gen' -> RnnGenerator, otherwise ConvRnnCoach.
        executor_path: path to the executor checkpoint.
        max_raw_chars: cap on raw instruction characters
            (default 200, the previously hard-coded value).
        cheat: cheat flag forwarded to ExecutorWrapper
            (default 0, the previously hard-coded value).
        inst_mode: instruction mode forwarded to ExecutorWrapper
            (default 'full', the previously hard-coded value).

    Returns:
        An ExecutorWrapper switched to eval mode via train(False).
    """
    if 'onehot' in coach_path:
        coach = ConvOneHotCoach.load(coach_path).to(device)
    elif 'gen' in coach_path:
        coach = RnnGenerator.load(coach_path).to(device)
    else:
        coach = ConvRnnCoach.load(coach_path).to(device)
    coach.max_raw_chars = max_raw_chars

    executor = Executor.load(executor_path).to(device)
    executor_wrapper = ExecutorWrapper(
        coach, executor, coach.num_instructions, max_raw_chars, cheat,
        inst_mode)
    executor_wrapper.train(False)
    return executor_wrapper
def load_model(self, coach_path, executor_paths, args):
    """Build a MultiExecutorWrapper from coach and executor checkpoints.

    `coach_path` may be a checkpoint path (str) or an already-constructed
    coach object, in which case it is shared as-is. `executor_paths` maps
    executor keys to checkpoint paths; each is loaded via Executor.rl_load.
    Returns the wrapper switched to eval mode.
    """
    coach_rule_emb_size = getattr(args, "coach_rule_emb_size", 0)
    executor_rule_emb_size = getattr(args, "executor_rule_emb_size", 0)
    inst_dict_path = getattr(args, "inst_dict_path", None)
    coach_random_init = getattr(args, "coach_random_init", False)

    assert isinstance(executor_paths, dict)

    if not isinstance(coach_path, str):
        # Already a coach instance: reuse it directly across agents.
        print("Sharing coaches.")
        coach = coach_path
    elif "onehot" in coach_path:
        coach = ConvOneHotCoach.load(coach_path).to(self.device)
    elif "gen" in coach_path:
        coach = RnnGenerator.load(coach_path).to(self.device)
    else:
        coach = ConvRnnCoach.rl_load(
            coach_path,
            coach_rule_emb_size,
            inst_dict_path,
            coach_random_init=coach_random_init,
        ).to(self.device)
    coach.max_raw_chars = args.max_raw_chars

    # One executor per key, all placed on this object's device.
    executors = {
        key: Executor.rl_load(
            path, executor_rule_emb_size, inst_dict_path).to(self.device)
        for key, path in executor_paths.items()
    }

    wrapper = MultiExecutorWrapper(
        coach,
        executors,
        coach.num_instructions,
        args.max_raw_chars,
        args.cheat,
        args.inst_mode,
    )
    wrapper.train(False)
    return wrapper
# Copyright (c) Facebook, Inc. and its affiliates.
def main():
    """Train a coach model (onehot / rnn / bow / rnn_gen) on coach datasets.

    Parses command-line options, builds the model, datasets and loaders,
    then runs the train/eval loop, checkpointing every epoch and keeping
    the best model by validation NLL. Stops early after 2 consecutive
    epochs without a new best validation NLL.
    """
    torch.backends.cudnn.benchmark = True

    parser = common_utils.Parser()
    parser.add_parser('main', get_main_parser())
    parser.add_parser('coach', ConvRnnCoach.get_arg_parser())
    args = parser.parse()
    parser.log()

    options = args['main']

    if not os.path.exists(options.model_folder):
        os.makedirs(options.model_folder)
    logger_path = os.path.join(options.model_folder, 'train.log')
    if not options.dev:
        # Tee stdout into the training log outside of dev mode.
        sys.stdout = common_utils.Logger(logger_path)

    if options.dev:
        # Dev mode swaps in the small dev split for quick iteration.
        options.train_dataset = options.train_dataset.replace('train.', 'dev.')
        options.val_dataset = options.val_dataset.replace('val.', 'dev.')

    print('Args:\n%s\n' % pprint.pformat(vars(options)))

    if options.gpu < 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:%d' % options.gpu)

    common_utils.set_all_seeds(options.seed)

    model_args = args['coach']
    if options.coach_type == 'onehot':
        model = ConvOneHotCoach(model_args, 0, options.max_instruction_span,
                                options.num_resource_bin).to(device)
    elif options.coach_type in ['rnn', 'bow']:
        model = ConvRnnCoach(model_args, 0, options.max_instruction_span,
                             options.coach_type,
                             options.num_resource_bin).to(device)
    elif options.coach_type == 'rnn_gen':
        model = RnnGenerator(model_args, 0, options.max_instruction_span,
                             options.num_resource_bin).to(device)
    else:
        # Previously an unknown coach_type fell through and crashed later
        # with a NameError on `model`; fail fast with a clear message.
        assert False, 'unknown coach_type: %s' % options.coach_type
    print(model)

    train_dataset = CoachDataset(
        options.train_dataset,
        options.moving_avg_decay,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        model.inst_dict,
        options.max_instruction_span,
    )
    val_dataset = CoachDataset(
        options.val_dataset,
        options.moving_avg_decay,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        model.inst_dict,
        options.max_instruction_span,
    )
    # Eval reuses the val data but restricts to the positive instruction
    # set (num_pos_inst) for the evaluation pass.
    eval_dataset = CoachDataset(
        options.val_dataset,
        options.moving_avg_decay,
        options.num_resource_bin,
        options.resource_bin_size,
        options.max_num_prev_cmds,
        model.inst_dict,
        options.max_instruction_span,
        num_instructions=model.args.num_pos_inst)

    if not options.dev:
        # Pre-compute dataset caches; skipped in dev mode for fast startup.
        compute_cache(train_dataset)
        compute_cache(val_dataset)
        compute_cache(eval_dataset)

    if options.optim == 'adamax':
        optimizer = torch.optim.Adamax(
            model.parameters(),
            lr=options.lr,
            betas=(options.beta1, options.beta2))
    elif options.optim == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=options.lr,
            betas=(options.beta1, options.beta2))
    else:
        assert False, 'not supported'

    train_loader = DataLoader(
        train_dataset,
        options.batch_size,
        shuffle=True,
        num_workers=1,  # if options.dev else 10,
        pin_memory=(options.gpu >= 0))
    val_loader = DataLoader(
        val_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=1,  # if options.dev else 10,
        pin_memory=(options.gpu >= 0))
    eval_loader = DataLoader(
        eval_dataset,
        options.batch_size,
        shuffle=False,
        num_workers=1,  # 0 if options.dev else 10,
        pin_memory=(options.gpu >= 0))

    best_val_nll = float('inf')
    overfit_count = 0
    for epoch in range(1, options.epochs + 1):
        print('==========')
        train(model, device, optimizer, options.grad_clip, train_loader, epoch)
        with torch.no_grad(), common_utils.EvalMode(model):
            val_nll = evaluate(model, device, val_loader, epoch, 'val', False)
            eval_nll = evaluate(model, device, eval_loader, epoch, 'eval', True)

        model_file = os.path.join(options.model_folder,
                                  'checkpoint%d.pt' % epoch)
        print('saving model to', model_file)
        model.save(model_file)

        if val_nll < best_val_nll:
            print('!!!New Best Model')
            overfit_count = 0
            best_val_nll = val_nll
            best_model_file = os.path.join(options.model_folder,
                                           'best_checkpoint.pt')
            print('saving best model to', best_model_file)
            model.save(best_model_file)
        else:
            # Early stopping: 2 consecutive epochs without improvement.
            overfit_count += 1
            if overfit_count == 2:
                break

    print('train DONE')
if __name__ == '__main__':
    # Script entry: load a trained coach + executor and launch game threads.
    args = parse_args()
    print('args:')
    pprint.pprint(vars(args))

    # The game engine loads Lua modules from this path.
    os.environ['LUA_PATH'] = os.path.join(args.lua_files, '?.lua')
    print('lua path:', os.environ['LUA_PATH'])

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    logger_path = os.path.join(args.save_dir, 'train.log')
    # Tee stdout into the log file from here on.
    sys.stdout = Logger(logger_path)

    # NOTE(review): assumes a CUDA device is available (no CPU fallback).
    device = torch.device('cuda:%d' % args.gpu)

    # Load coach + executor checkpoints and wrap them in eval mode.
    coach = ConvRnnCoach.load(args.coach_path).to(device)
    coach.max_raw_chars = args.max_raw_chars
    executor = Executor.load(args.model_path).to(device)
    executor_wrapper = ExecutorWrapper(
        coach, executor, coach.num_instructions, args.max_raw_chars,
        args.cheat, args.inst_mode)
    executor_wrapper.train(False)

    # Configure and start the game: num_thread parallel games, two AIs.
    game_option = get_game_option(args)
    ai1_option, ai2_option = get_ai_options(args, coach.num_instructions)
    context, act_dc = create_game(
        args.num_thread, ai1_option, ai2_option, game_option)
    context.start()
    # Channel manager for pulling action batches from the running games.
    dc = DataChannelManager([act_dc])