def run(args):
    """Run one training job selected by ``args.train_mode``.

    Prints and saves ``args``, calls ``util.set_seed(args.seed)``, obtains
    the models from ``load_or_init_models``, runs the training routine that
    matches ``args.train_mode`` and finally saves the models and the
    training callback object.

    Args:
        args: attribute container. This function reads ``seed``,
            ``load_model_folder``, ``pcfg_path``, ``train_mode``,
            ``batch_size``, ``num_iterations``, ``num_particles``,
            ``logging_interval``, ``checkpoint_interval`` and
            ``eval_interval``.

    Raises:
        ValueError: if ``args.train_mode`` is not one of ``'ww'``, ``'ws'``,
            ``'reinforce'``, ``'vimco'`` or ``'relax'``.
    """
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_filename(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        load_or_init_models(args.load_model_folder, args.pcfg_path)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(generative_model.grammar)

    # Every callback class below is constructed with the same positional
    # arguments, so build them once.
    callback_args = (args.pcfg_path, model_folder, true_generative_model,
                     args.logging_interval, args.checkpoint_interval,
                     args.eval_interval)

    # train
    if args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(*callback_args)
        train.train_wake_wake(generative_model, inference_network,
                              true_generative_model, args.batch_size,
                              args.num_iterations, args.num_particles,
                              train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(*callback_args)
        train.train_wake_sleep(generative_model, inference_network,
                               true_generative_model, args.batch_size,
                               args.num_iterations, args.num_particles,
                               train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(*callback_args)
        train.train_iwae(args.train_mode, generative_model,
                         inference_network, true_generative_model,
                         args.batch_size, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(*callback_args)
        train.train_relax(generative_model, inference_network,
                          control_variate, true_generative_model,
                          args.batch_size, args.num_iterations,
                          args.num_particles, train_callback)
    else:
        # Without this branch an unknown mode would skip training and then
        # fail later with UnboundLocalError on ``train_callback``.
        raise ValueError(
            'Unknown train_mode: {!r}'.format(args.train_mode))

    # save models and stats
    util.save_models(generative_model, inference_network, args.pcfg_path,
                     model_folder)
    stats_filename = util.get_stats_filename(model_folder)
    util.save_object(train_callback, stats_filename)
def load_control_variate(model_folder='.', iteration=None):
    """Load a saved control variate from ``model_folder``.

    The state dict is read from ``c{iteration}.pt`` inside ``model_folder``
    (``c.pt`` when ``iteration`` is None). The ``models.ControlVariate`` is
    constructed with ``num_mixtures`` taken from the args object loaded via
    ``load_object(get_args_path(model_folder))``.

    Args:
        model_folder: folder holding the checkpoint and the saved args.
        iteration: value formatted into the checkpoint file name, or None
            for the un-suffixed file.

    Returns:
        The ``models.ControlVariate`` instance with the loaded state dict.
    """
    suffix = '' if iteration is None else iteration
    path = os.path.join(model_folder, 'c{}.pt'.format(suffix))

    args = load_object(get_args_path(model_folder))
    control_variate = models.ControlVariate(args.num_mixtures)
    control_variate.load_state_dict(torch.load(path))
    print_with_time('Loaded from {}'.format(path))

    # Hand the loaded module back; previously it was built and discarded,
    # so callers always received None.
    return control_variate
def load_control_variate(model_folder='.', iteration=None):
    """Load a saved control variate from ``model_folder``.

    The state dict is read from ``c{iteration}.pt`` inside ``model_folder``
    (``c.pt`` when ``iteration`` is None). The grammar passed to
    ``models.ControlVariate`` is obtained by calling ``read_pcfg`` on the
    path stored in ``pcfg_path.txt`` in the same folder.

    Args:
        model_folder: folder holding the checkpoint and ``pcfg_path.txt``.
        iteration: value formatted into the checkpoint file name, or None
            for the un-suffixed file.

    Returns:
        The ``models.ControlVariate`` instance with the loaded state dict.
    """
    suffix = '' if iteration is None else iteration
    path = os.path.join(model_folder, 'c{}.pt'.format(suffix))

    # The PCFG location is stored next to the checkpoint as plain text.
    pcfg_path_path = os.path.join(model_folder, 'pcfg_path.txt')
    with open(pcfg_path_path) as f:
        pcfg_path = f.read()
    grammar, _ = read_pcfg(pcfg_path)

    control_variate = models.ControlVariate(grammar)
    control_variate.load_state_dict(torch.load(path))
    print_with_time('Loaded from {}'.format(path))

    # Hand the loaded module back; previously it was built and discarded,
    # so callers always received None.
    return control_variate
def run(args):
    """Run one training job selected by ``args.train_mode``.

    Fills in derived settings on ``args``, prints and saves it, calls
    ``util.set_seed(args.seed)``, initialises the models via
    ``util.init_models``, builds a shuffled ``DataLoader`` over
    ``true_generative_model.sample_obs(args.num_obss)``, runs the training
    routine matching ``args.train_mode`` and finally saves the models (plus
    the control variate in ``'relax'`` mode) and the training callback.

    Args:
        args: attribute container. Read here: ``cuda``, ``init_near``,
            ``train_mode``, ``seed``, ``num_obss``, ``batch_size``,
            ``num_iterations``, ``num_particles``, ``mws_memory_size``,
            ``logging_interval``, ``checkpoint_interval``,
            ``eval_interval``. Set here (mutated in place): ``device``,
            ``num_mixtures``, ``init_mixture_logits``,
            ``softmax_multiplier``, ``relaxed_one_hot``, ``temperature``,
            ``true_mixture_logits``.

    Raises:
        ValueError: if ``args.train_mode`` is not one of ``'mws'``,
            ``'ws'``, ``'ww'``, ``'dww'``, ``'reinforce'``, ``'vimco'``,
            ``'concrete'`` or ``'relax'``.
    """
    # set up args
    if args.cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')

    args.num_mixtures = 20
    if args.init_near:
        args.init_mixture_logits = np.ones(args.num_mixtures)
    else:
        # Even numbers 2 * (num_mixtures - 1), ..., 2, 0 in descending order.
        args.init_mixture_logits = np.array(
            list(reversed(2 * np.arange(args.num_mixtures))))
    args.softmax_multiplier = 0.5

    if args.train_mode == 'concrete':
        args.relaxed_one_hot = True
        args.temperature = 3
    else:
        args.relaxed_one_hot = False
        args.temperature = None

    # Normalise 5, 6, ..., num_mixtures + 4 into probabilities, then divide
    # their logs by the softmax multiplier.
    temp = np.arange(args.num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    args.true_mixture_logits = \
        np.log(true_p_mixture_probs) / args.softmax_multiplier
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_path(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(args.num_mixtures)

    # init dataloader
    obss_data_loader = torch.utils.data.DataLoader(
        true_generative_model.sample_obs(args.num_obss),
        batch_size=args.batch_size, shuffle=True)

    # train
    if args.train_mode == 'mws':
        train_callback = train.TrainMWSCallback(
            model_folder, true_generative_model, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_mws(generative_model, inference_network,
                        obss_data_loader, args.num_iterations,
                        args.mws_memory_size, train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            model_folder, true_generative_model,
            args.batch_size * args.num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               obss_data_loader, args.num_iterations,
                               args.num_particles, train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              obss_data_loader, args.num_iterations,
                              args.num_particles, train_callback)
    elif args.train_mode == 'dww':
        # The same constant is handed to both the callback and the trainer;
        # keep it in one place so the two cannot drift apart.
        # NOTE(review): exact meaning of 0.2 is defined by the train module.
        dww_constant = 0.2
        train_callback = train.TrainDefensiveWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles,
            dww_constant, args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_defensive_wake_wake(
            dww_constant, generative_model, inference_network,
            obss_data_loader, args.num_iterations, args.num_particles,
            train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.train_mode, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model,
                         inference_network, obss_data_loader,
                         args.num_iterations, args.num_particles,
                         train_callback)
    elif args.train_mode == 'concrete':
        train_callback = train.TrainConcreteCallback(
            model_folder, true_generative_model, args.num_particles,
            args.num_iterations, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        # 'concrete' reuses the IWAE training entry point.
        train.train_iwae(args.train_mode, generative_model,
                         inference_network, obss_data_loader,
                         args.num_iterations, args.num_particles,
                         train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(
            model_folder, true_generative_model, args.num_particles,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_relax(generative_model, inference_network,
                          control_variate, obss_data_loader,
                          args.num_iterations, args.num_particles,
                          train_callback)
    else:
        # Without this branch an unknown mode would skip training and then
        # fail later with UnboundLocalError on ``train_callback``.
        raise ValueError(
            'Unknown train_mode: {!r}'.format(args.train_mode))

    # save models and stats
    util.save_models(generative_model, inference_network, model_folder)
    if args.train_mode == 'relax':
        util.save_control_variate(control_variate, model_folder)
    stats_filename = util.get_stats_path(model_folder)
    util.save_object(train_callback, stats_filename)