Example #1
        ## size of latents flattened - dependent on architecture of vqvae
        #info['float_condition_size'] = 100*args.num_z
        ## 3x logistic needed for loss
        ## TODO - change loss
    else:
        print('loading model from: %s' % args.model_loadpath)
        model_dict = torch.load(args.model_loadpath)
        info = model_dict['info']
        model_base_filedir = os.path.split(args.model_loadpath)[0]
        model_base_filepath = os.path.join(model_base_filedir, args.savename)
        train_cnt = info['train_cnts'][-1]
        info['loaded_from'] = args.model_loadpath
    train_data_loader = AtariDataset(train_data_file,
                                     number_condition=args.number_condition,
                                     steps_ahead=1,
                                     batch_size=args.batch_size,
                                     norm_by=info['norm_by'])
    valid_data_loader = AtariDataset(valid_data_file,
                                     number_condition=args.number_condition,
                                     steps_ahead=1,
                                     batch_size=args.batch_size,
                                     norm_by=info['norm_by'])
    num_actions = train_data_loader.n_actions
    args.size_training_set = train_data_loader.num_examples
    hsize = train_data_loader.data_h
    wsize = train_data_loader.data_w
    # output mixtures should be 2*nr_logistic_mix + nr_logistic_mix for each
    # decorrelated channel
    info['num_channels'] = 2
    info['num_output_mixtures'] = (2 * args.nr_logistic_mix +
                                   args.nr_logistic_mix) * info['num_channels']
Example #2
        sys.exit()

    output_savepath = model_loadpath.replace('.pt', '_samples')
    if not os.path.exists(output_savepath):
        os.makedirs(output_savepath)
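    # map_location below keeps every tensor on the CPU, so a checkpoint
    # trained on a GPU loads on a CPU-only machine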
    model_dict = torch.load(model_loadpath, map_location=lambda storage, loc: storage)
    info = model_dict['info']
    largs = info['args'][-1]

    run_num = 0
    train_data_file = largs.train_data_file
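    # the validation file is assumed to live beside the training file, with
    # 'training' swapped for 'valid' in the filename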
    valid_data_file = largs.train_data_file.replace('training', 'valid')

    train_data_loader = AtariDataset(
        train_data_file,
        number_condition=4,
        steps_ahead=1,
        batch_size=args.batch_size,
        norm_by=255.,
    )
    valid_data_loader = AtariDataset(
        valid_data_file,
        number_condition=4,
        steps_ahead=1,
        batch_size=largs.batch_size,
        norm_by=255.0,
    )

    num_actions = valid_data_loader.n_actions
    args.size_training_set = valid_data_loader.num_examples
    hsize = valid_data_loader.data_h
    wsize = valid_data_loader.data_w

    # arguments after num_clusters were cut off in this snippet; the call is
    # completed here following the VQVAE constructor in Example #4, so the
    # exact keywords are an assumption rather than the original code
    vqvae_model = VQVAE(num_clusters=largs.num_k,
                        encoder_output_size=largs.num_z,
                        num_output_mixtures=info['num_output_mixtures'],
                        in_channels_size=largs.number_condition).to(DEVICE)

Example #3
    output_savepath = model_loadpath.replace('.pt', '_samples')
    if not os.path.exists(output_savepath):
        os.makedirs(output_savepath)
    model_dict = torch.load(model_loadpath,
                            map_location=lambda storage, loc: storage)
    info = model_dict['info']
    largs = info['args'][-1]

    run_num = 0
    train_data_file = largs.train_data_file
    valid_data_file = largs.train_data_file.replace('training', 'valid')

    train_data_loader = AtariDataset(
        train_data_file,
        number_condition=4,
        steps_ahead=1,
        batch_size=args.batch_size,
        norm_by=255.,
    )
    valid_data_loader = AtariDataset(
        valid_data_file,
        number_condition=4,
        steps_ahead=1,
        batch_size=args.batch_size,
        norm_by=255.0,
    )

    args.size_training_set = valid_data_loader.num_examples
    hsize = valid_data_loader.data_h
    wsize = valid_data_loader.data_w
Example #4
def init_train():
    train_data_file = args.train_data_file
    data_dir = os.path.split(train_data_file)[0]
    #valid_data_file = train_data_file.replace('training', 'valid')
    valid_data_file = '/usr/local/data/jhansen/planning/model_savedir/FRANKbootstrap_priorfreeway00/valid_set_small.npz'
    if args.model_loadpath == '':
        train_cnt = 0
        run_num = 0
        model_base_filedir = os.path.join(data_dir,
                                          args.savename + '%02d' % run_num)
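        # bump run_num until an unused savenameNN directory is found, so
        # repeated runs never clobber earlier results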
        while os.path.exists(model_base_filedir):
            run_num += 1
            model_base_filedir = os.path.join(data_dir,
                                              args.savename + '%02d' % run_num)
        os.makedirs(model_base_filedir)
        model_base_filepath = os.path.join(model_base_filedir, args.savename)
        print("MODEL BASE FILEPATH", model_base_filepath)

        info = {
            'vq_train_cnts': [],
            'vq_train_losses_list': [],
            'vq_valid_cnts': [],
            'vq_valid_losses_list': [],
            'vq_save_times': [],
            'vq_last_save': 0,
            'vq_last_plot': 0,
            'NORM_BY': 255.0,
            'vq_model_loadpath': args.model_loadpath,
            'vq_model_base_filedir': model_base_filedir,
            'vq_model_base_filepath': model_base_filepath,
            'vq_train_data_file': args.train_data_file,
            'VQ_SAVENAME': args.savename,
            'DEVICE': DEVICE,
            'VQ_NUM_EXAMPLES_TO_TRAIN': args.num_examples_to_train,
            'NUM_Z': args.num_z,
            'NUM_K': args.num_k,
            'NR_LOGISTIC_MIX': args.nr_logistic_mix,
            'BETA': args.beta,
            'ALPHA_REC': args.alpha_rec,
            'ALPHA_ACT': args.alpha_act,
            'ALPHA_REW': args.alpha_rew,
            'VQ_BATCH_SIZE': args.batch_size,
            'NUMBER_CONDITION': args.number_condition,
            'VQ_LEARNING_RATE': args.learning_rate,
            'VQ_SAVE_EVERY': args.save_every,
            'VQ_MIN_BATCHES_BEFORE_SAVE': args.min_batches,
            'REWARD_SPACE': [-1, 0, 1],
            'action_space': [0, 1, 2],
        }
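        # this info dict is saved inside the checkpoint (see the
        # model_dict['vq_info'] lookup in the else branch), so a run can be
        # resumed with all of its hyperparameters intact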

        ## size of latents flattened - dependent on architecture of vqvae
        #info['float_condition_size'] = 100*args.num_z
        ## 3x logistic needed for loss
        ## TODO - change loss
    else:
        print('loading model from: %s' % args.model_loadpath)
        model_dict = torch.load(args.model_loadpath)
        info = model_dict['vq_info']
        model_base_filedir = os.path.split(args.model_loadpath)[0]
        model_base_filepath = os.path.join(model_base_filedir, args.savename)
        train_cnt = info['vq_train_cnts'][-1]
        info['loaded_from'] = args.model_loadpath
        info['VQ_BATCH_SIZE'] = args.batch_size
        #if 'reward_weights' not in info.keys():
        #    info['reward_weights'] = [1,100]
    train_data_loader = AtariDataset(train_data_file,
                                     number_condition=info['NUMBER_CONDITION'],
                                     steps_ahead=1,
                                     batch_size=info['VQ_BATCH_SIZE'],
                                     norm_by=info['NORM_BY'],
                                     unique_actions=info['action_space'],
                                     unique_rewards=info['REWARD_SPACE'])
    train_data_loader.plot_dataset()
    valid_data_loader = AtariDataset(valid_data_file,
                                     number_condition=info['NUMBER_CONDITION'],
                                     steps_ahead=1,
                                     batch_size=info['VQ_BATCH_SIZE'],
                                     norm_by=info['NORM_BY'],
                                     unique_actions=info['action_space'],
                                     unique_rewards=info['REWARD_SPACE'])
    #info['num_actions'] = train_data_loader.n_actions
    info['num_actions'] = len(info['action_space'])
    info['num_rewards'] = len(info['REWARD_SPACE'])
    info['size_training_set'] = train_data_loader.num_examples
    info['hsize'] = train_data_loader.data_h
    info['wsize'] = train_data_loader.data_w

    #reward_loss_weight = torch.ones(info['num_rewards']).to(DEVICE)
    #for i, w  in enumerate(info['reward_weights']):
    #    reward_loss_weight[i] *= w
    actions_weight = 1 - np.array(train_data_loader.percentages_actions)
    rewards_weight = 1 - np.array(train_data_loader.percentages_rewards)
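    # a class that appears with frequency p gets loss weight 1 - p, so
    # rare actions and rewards are emphasized relative to common ones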
    actions_weight = torch.FloatTensor(actions_weight).to(DEVICE)
    rewards_weight = torch.FloatTensor(rewards_weight).to(DEVICE)
    info['actions_weight'] = actions_weight
    info['rewards_weight'] = rewards_weight

    # output mixtures should be 2*nr_logistic_mix + nr_logistic_mix for each
    # decorrelated channel
    info['num_channels'] = 2
    info['num_output_mixtures'] = (2 * args.nr_logistic_mix +
                                   args.nr_logistic_mix) * info['num_channels']
    nmix = int(info['num_output_mixtures'] / 2)
    info['nmix'] = nmix
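    # e.g. with nr_logistic_mix=10 and 2 channels: (2*10 + 10) * 2 = 60
    # output maps (a mean, scale, and mixture weight per component per
    # channel), and nmix = 30 per channel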
    vqvae_model = VQVAE(
        num_clusters=info['NUM_K'],
        encoder_output_size=info['NUM_Z'],
        num_output_mixtures=info['num_output_mixtures'],
        in_channels_size=info['NUMBER_CONDITION'],
        n_actions=info['num_actions'],
        int_reward=info['num_rewards'],
    ).to(DEVICE)

    print('using args', args)
    parameters = list(vqvae_model.parameters())
    opt = optim.Adam(parameters, lr=info['VQ_LEARNING_RATE'])
    if args.model_loadpath != '':
        print("loading weights from:%s" % args.model_loadpath)
        vqvae_model.load_state_dict(model_dict['vqvae_state_dict'])
        opt.load_state_dict(model_dict['vq_optimizer'])
        vqvae_model.embedding = model_dict['vq_embedding']
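        # restoring the optimizer state and codebook embedding (not just the
        # weights) lets training resume exactly where the checkpoint stopped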

    #args.pred_output_size = 1*80*80
    ## 10 is result of structure of network
    #args.z_input_size = 10*10*args.num_z
    train_cnt = train_vqvae(train_cnt, vqvae_model, opt, info,
                            train_data_loader, valid_data_loader)
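
init_train reads everything from a module-level args object. As a rough
sketch of the contract it assumes, a driver along these lines would supply
the fields referenced above (the flag names mirror the args.* attributes
used in the snippet; every default is an illustrative assumption, not a
value from the original):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # flags mirror the args.* attributes referenced in init_train;
    # defaults are illustrative assumptions
    parser.add_argument('--train_data_file', required=True)
    parser.add_argument('--model_loadpath', default='')
    parser.add_argument('--savename', default='vq')
    parser.add_argument('--batch_size', type=int, default=84)
    parser.add_argument('--number_condition', type=int, default=4)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--num_examples_to_train', type=int, default=int(1e7))
    parser.add_argument('--num_z', type=int, default=64)
    parser.add_argument('--num_k', type=int, default=512)
    parser.add_argument('--nr_logistic_mix', type=int, default=10)
    parser.add_argument('--beta', type=float, default=0.25)
    parser.add_argument('--alpha_rec', type=float, default=1.0)
    parser.add_argument('--alpha_act', type=float, default=1.0)
    parser.add_argument('--alpha_rew', type=float, default=1.0)
    parser.add_argument('--save_every', type=int, default=50000)
    parser.add_argument('--min_batches', type=int, default=100)
    args = parser.parse_args()
    init_train()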
Example #5
        episode_batch = (df['states'], df['actions'], df['rewards'],
                         df['values'], df['next_states'], fake_terminals,
                         False, np.arange(nn))
        #states, actions, rewards, values, next_states, terminals, reset, relative_indexes = data

    else:
        train_data_file = vq_largs.train_data_file
        valid_data_file = vq_largs.train_data_file.replace('training', 'valid')

        if args.train:
            print("USING TRAINING DATA")
            name = 'train'
            data_loader = AtariDataset(
                train_data_file,
                number_condition=4,
                steps_ahead=1,
                batch_size=args.batch_size,
                norm_by=255.,
            )
            episode_batch, episode_index, episode_reward = data_loader.get_entire_episode(
                diff=False, limit=args.limit, min_reward=args.min_reward)
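            # get_entire_episode hands back one whole episode, its index, and
            # its reward, rather than a minibatch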
        else:
            name = 'valid'
            data_loader = AtariDataset(
                valid_data_file,
                number_condition=4,
                steps_ahead=1,
                batch_size=args.batch_size,
                norm_by=255.0,
            )
            # the call was truncated in the original; completed to mirror the
            # training branch above
            episode_batch, episode_index, episode_reward = data_loader.get_entire_episode(
                diff=False, limit=args.limit, min_reward=args.min_reward)