Example #1
0
def run(args):
    """Train the models selected by ``args.train_mode`` and save the results.

    Logs ``args``, saves them into the folder returned by
    ``util.get_model_folder()``, seeds the RNG, loads or initializes the
    generative model, inference network and true generative model, trains
    them with the algorithm chosen by ``args.train_mode`` and finally saves
    the trained models and the training callback object.

    Args:
        args: Namespace with (at least) the attributes ``seed``,
            ``load_model_folder``, ``pcfg_path``, ``train_mode``,
            ``batch_size``, ``num_iterations``, ``num_particles``,
            ``logging_interval``, ``checkpoint_interval`` and
            ``eval_interval``.

    Raises:
        ValueError: If ``args.train_mode`` is not one of ``'ww'``, ``'ws'``,
            ``'reinforce'``, ``'vimco'`` or ``'relax'``.
    """
    util.print_with_time(str(args))

    # Fail fast: without this check an unknown mode matched no training
    # branch, and only crashed with a NameError on ``train_callback`` at the
    # very end, after the untrained models had already been saved.
    valid_train_modes = ('ww', 'ws', 'reinforce', 'vimco', 'relax')
    if args.train_mode not in valid_train_modes:
        raise ValueError('Unknown train_mode {!r}; expected one of {}'.format(
            args.train_mode, valid_train_modes))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_filename(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        load_or_init_models(args.load_model_folder, args.pcfg_path)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(generative_model.grammar)

    # Every callback class used below takes the same positional arguments.
    callback_args = (args.pcfg_path, model_folder, true_generative_model,
                     args.logging_interval, args.checkpoint_interval,
                     args.eval_interval)

    # train
    if args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(*callback_args)
        train.train_wake_wake(generative_model, inference_network,
                              true_generative_model, args.batch_size,
                              args.num_iterations, args.num_particles,
                              train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(*callback_args)
        train.train_wake_sleep(generative_model, inference_network,
                               true_generative_model, args.batch_size,
                               args.num_iterations, args.num_particles,
                               train_callback)
    elif args.train_mode in ('reinforce', 'vimco'):
        train_callback = train.TrainIwaeCallback(*callback_args)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         true_generative_model, args.batch_size,
                         args.num_iterations, args.num_particles,
                         train_callback)
    else:  # 'relax' -- guaranteed by the validation above
        train_callback = train.TrainRelaxCallback(*callback_args)
        train.train_relax(generative_model, inference_network, control_variate,
                          true_generative_model, args.batch_size,
                          args.num_iterations, args.num_particles,
                          train_callback)

    # save models and stats
    # NOTE(review): in 'relax' mode the trained control variate is not saved
    # here; confirm whether it should be persisted alongside the models.
    util.save_models(generative_model, inference_network, args.pcfg_path,
                     model_folder)
    stats_filename = util.get_stats_filename(model_folder)
    util.save_object(train_callback, stats_filename)
Example #2
0
def load_control_variate(model_folder='.', iteration=None):
    """Load a saved control variate for the mixture model.

    The weights are read from ``<model_folder>/c<iteration>.pt``, or from
    ``<model_folder>/c.pt`` when ``iteration`` is None. The number of
    mixture components needed to construct ``models.ControlVariate`` is
    taken from the args object saved in ``model_folder``.

    Args:
        model_folder: Folder containing the saved args and checkpoint.
        iteration: Optional checkpoint iteration appended to the file name.

    Returns:
        The ``models.ControlVariate`` with its saved state dict loaded.
    """
    suffix = '' if iteration is None else iteration
    path = os.path.join(model_folder, 'c{}.pt'.format(suffix))
    args = load_object(get_args_path(model_folder))
    control_variate = models.ControlVariate(args.num_mixtures)
    control_variate.load_state_dict(torch.load(path))
    print_with_time('Loaded from {}'.format(path))
    # The loaded object was previously discarded (implicit ``None`` return).
    return control_variate
Example #3
0
File: util.py Project: yyht/rrws
def load_control_variate(model_folder='.', iteration=None):
    """Load a saved control variate built for the model's PCFG grammar.

    The weights are read from ``<model_folder>/c<iteration>.pt``, or from
    ``<model_folder>/c.pt`` when ``iteration`` is None. The grammar used to
    construct ``models.ControlVariate`` comes from the PCFG file whose path
    is stored in ``<model_folder>/pcfg_path.txt``.

    Args:
        model_folder: Folder containing the checkpoint and pcfg_path.txt.
        iteration: Optional checkpoint iteration appended to the file name.

    Returns:
        The ``models.ControlVariate`` with its saved state dict loaded.
    """
    suffix = '' if iteration is None else iteration
    path = os.path.join(model_folder, 'c{}.pt'.format(suffix))
    pcfg_path_path = os.path.join(model_folder, 'pcfg_path.txt')
    with open(pcfg_path_path) as f:
        # Strip so a trailing newline in the text file does not become part
        # of the path handed to read_pcfg.
        pcfg_path = f.read().strip()
    grammar, _ = read_pcfg(pcfg_path)
    control_variate = models.ControlVariate(grammar)
    control_variate.load_state_dict(torch.load(path))
    print_with_time('Loaded from {}'.format(path))
    # The loaded object was previously discarded (implicit ``None`` return).
    return control_variate
Example #4
0
def run(args):
    """Derive settings on ``args``, train the selected model, save results.

    Writes the derived settings ``device``, ``num_mixtures``,
    ``init_mixture_logits``, ``softmax_multiplier``, ``relaxed_one_hot``,
    ``temperature`` and ``true_mixture_logits`` onto ``args``. It then logs
    and saves ``args``, seeds the RNG and initializes the models with
    ``util.init_models``. It builds a shuffled ``DataLoader`` over
    ``args.num_obss`` observations sampled from the true generative model and
    trains with the algorithm chosen by ``args.train_mode``. Finally it saves
    the models, the control variate (``'relax'`` only) and the training
    callback.

    Args:
        args: Namespace with (at least) ``cuda``, ``init_near``,
            ``train_mode``, ``seed``, ``num_obss``, ``batch_size``,
            ``num_iterations``, ``num_particles``, ``logging_interval``,
            ``checkpoint_interval``, ``eval_interval`` and, for ``'mws'``,
            ``mws_memory_size``. It is mutated in place.

    Raises:
        ValueError: If ``args.train_mode`` is not one of ``'mws'``, ``'ws'``,
            ``'ww'``, ``'dww'``, ``'reinforce'``, ``'vimco'``,
            ``'concrete'`` or ``'relax'``.
    """
    # set up args
    use_cuda = args.cuda and torch.cuda.is_available()
    args.device = torch.device('cuda' if use_cuda else 'cpu')
    args.num_mixtures = 20
    if args.init_near:
        args.init_mixture_logits = np.ones(args.num_mixtures)
    else:
        # Descending logits: 2 * (num_mixtures - 1), ..., 2, 0.
        args.init_mixture_logits = 2 * np.arange(args.num_mixtures)[::-1]
    args.softmax_multiplier = 0.5
    if args.train_mode == 'concrete':
        args.relaxed_one_hot = True
        args.temperature = 3
    else:
        args.relaxed_one_hot = False
        args.temperature = None
    temp = np.arange(args.num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    # NOTE(review): presumably the model computes
    # softmax(softmax_multiplier * logits); dividing log-probs by the
    # multiplier makes that equal true_p_mixture_probs -- confirm in models.
    args.true_mixture_logits = \
        np.log(true_p_mixture_probs) / args.softmax_multiplier
    util.print_with_time(str(args))

    # Fail fast: without this check an unknown mode matched no training
    # branch and crashed with a NameError on ``train_callback`` only after
    # the untrained models had been saved.
    valid_train_modes = ('mws', 'ws', 'ww', 'dww', 'reinforce', 'vimco',
                         'concrete', 'relax')
    if args.train_mode not in valid_train_modes:
        raise ValueError('Unknown train_mode {!r}; expected one of {}'.format(
            args.train_mode, valid_train_modes))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_path(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(args.num_mixtures)

    # init dataloader
    obss_data_loader = torch.utils.data.DataLoader(
        true_generative_model.sample_obs(args.num_obss),
        batch_size=args.batch_size,
        shuffle=True)

    # train (a single if/elif chain: exactly one branch runs)
    if args.train_mode == 'mws':
        train_callback = train.TrainMWSCallback(model_folder,
                                                true_generative_model,
                                                args.logging_interval,
                                                args.checkpoint_interval,
                                                args.eval_interval)
        train.train_mws(generative_model, inference_network, obss_data_loader,
                        args.num_iterations, args.mws_memory_size,
                        train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            model_folder, true_generative_model,
            args.batch_size * args.num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               obss_data_loader, args.num_iterations,
                               args.num_particles, train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(model_folder,
                                                     true_generative_model,
                                                     args.num_particles,
                                                     args.logging_interval,
                                                     args.checkpoint_interval,
                                                     args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              obss_data_loader, args.num_iterations,
                              args.num_particles, train_callback)
    elif args.train_mode == 'dww':
        # Same constant must reach both the callback and the trainer.
        # NOTE(review): presumably the defensive-mixture weight; verify in
        # train.train_defensive_wake_wake.
        defensive_weight = 0.2
        train_callback = train.TrainDefensiveWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles,
            defensive_weight, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_defensive_wake_wake(defensive_weight, generative_model,
                                        inference_network, obss_data_loader,
                                        args.num_iterations,
                                        args.num_particles, train_callback)
    elif args.train_mode in ('reinforce', 'vimco'):
        train_callback = train.TrainIwaeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.train_mode, args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'concrete':
        train_callback = train.TrainConcreteCallback(
            model_folder, true_generative_model, args.num_particles,
            args.num_iterations, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    else:  # 'relax' -- guaranteed by the validation above
        train_callback = train.TrainRelaxCallback(model_folder,
                                                  true_generative_model,
                                                  args.num_particles,
                                                  args.logging_interval,
                                                  args.checkpoint_interval,
                                                  args.eval_interval)
        train.train_relax(generative_model, inference_network, control_variate,
                          obss_data_loader, args.num_iterations,
                          args.num_particles, train_callback)

    # save models and stats
    util.save_models(generative_model, inference_network, model_folder)
    if args.train_mode == 'relax':
        util.save_control_variate(control_variate, model_folder)
    stats_filename = util.get_stats_path(model_folder)
    util.save_object(train_callback, stats_filename)