Example #1
def finetune():
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    if args.fast_train:
        feat_dict = read_img_features(features_fast)
    else:
        feat_dict = read_img_features(features)

    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    train_env = R2RBatch(feat_dict,
                         candidate_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    print("The finetune data_size is : %d\n" % train_env.size())
    val_envs = {
        split:
        (R2RBatch(feat_dict,
                  candidate_dict,
                  batch_size=args.batchSize,
                  splits=[split],
                  tokenizer=tok), Evaluation([split], featurized_scans, tok))
        for split in ['train', 'val_seen', 'val_unseen']
    }

    train(train_env, tok, args.iters, val_envs=val_envs)
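The skeleton above recurs throughout the examples below: build a Tokenizer from the training vocabulary, load the precomputed image features, construct an R2RBatch training environment, and pair one R2RBatch per validation split with an Evaluation object. A minimal sketch of that shared pattern, assuming the same helpers (read_vocab, Tokenizer, read_img_features, R2RBatch, Evaluation) and module-level constants (TRAIN_VOCAB, features, args) used above are importable from the surrounding project:

def build_envs(val_splits=('train', 'val_seen', 'val_unseen')):
    # Text preprocessing: vocabulary plus a tokenizer with a fixed encoding length
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    # Precomputed image features, keyed as "<scan>_<viewpoint>"
    feat_dict = read_img_features(features)
    featurized_scans = set(key.split("_")[0] for key in feat_dict.keys())

    # One batched environment for training; one (env, evaluator) pair per validation split
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize,
                         splits=['train'], tokenizer=tok)
    val_envs = {
        split: (R2RBatch(feat_dict, batch_size=args.batchSize,
                         splits=[split], tokenizer=tok),
                Evaluation([split], featurized_scans, tok))
        for split in val_splits
    }
    return train_env, tok, val_envs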
Example #2
def hard_negative():
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    if args.fast_train:
        feat_dict = read_img_features(features_fast)
    else:
        feat_dict = read_img_features(features)

    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)

    gt_train_env, gt_val_seen_env, gt_val_unseen_env = gt_envs = list(
        R2RBatch(feat_dict,
                 candidate_dict,
                 batch_size=args.batchSize,
                 splits=[split],
                 tokenizer=tok)
        for split in ['train', 'val_seen', 'val_unseen'])
    neg_train_env, neg_val_seen_env, neg_val_unseen_env = neg_envs = list(
        R2RBatch(feat_dict,
                 candidate_dict,
                 batch_size=args.batchSize,
                 splits=[split + "_instneg", split + "_pathneg"],
                 tokenizer=tok)
        for split in ['train', 'val_seen', 'val_unseen'])
    arbiter_train_env, arbiter_val_seen_env, arbiter_val_unseen_env = (
        ArbiterBatch(gt_env,
                     neg_env,
                     args.batchSize // 2,
                     args.batchSize // 2,
                     feat_dict,
                     candidate_dict,
                     batch_size=args.batchSize,
                     splits=[],
                     tokenizer=tok)
        for gt_env, neg_env in zip(gt_envs, neg_envs))
    train_arbiter(arbiter_train_env,
                  tok,
                  args.iters,
                  val_envs={
                      'train': arbiter_train_env,
                      'val_seen': arbiter_val_seen_env,
                      'val_unseen': arbiter_val_unseen_env,
                  })
Example #3
def train_val_augment():
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    # Load the env img features
    feat_dict = read_img_features(features)
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    # Load the augmentation data
    aug_path = args.aug

    # Create the training environment
    aug_env = R2RBatch(feat_dict,
                       batch_size=args.batchSize,
                       splits=[aug_path],
                       tokenizer=tok,
                       name='aug')

    # import sys
    # sys.exit()
    train_env = R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)

    # Printing out the statistics of the dataset
    stats = train_env.get_statistics()
    print("The training data_size is : %d" % train_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))
    stats = aug_env.get_statistics()
    print("The augmentation data size is %d" % aug_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))

    # Setup the validation data
    val_envs = {
        split:
        (R2RBatch(feat_dict,
                  batch_size=args.batchSize,
                  splits=[split],
                  tokenizer=tok), Evaluation([split], featurized_scans, tok))
        for split in ['train', 'val_seen', 'val_unseen']
    }

    # Start training
    train(train_env, tok, args.iters, val_envs=val_envs, aug_env=aug_env)
Example #4
def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set(
            [key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    if not args.test_obj:
        print('Loading compact pano-caffe object features ... (~3 seconds)')
        import pickle as pkl
        with open('img_features/objects/pano_object_class.pkl', 'rb') as f_pc:
            pano_caffe = pkl.load(f_pc)
    else:
        pano_caffe = None

    train_env = R2RBatch(feat_dict,
                         pano_caffe,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    from collections import OrderedDict

    if args.submit:
        val_env_names.append('test')

    val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                              pano_caffe,
                                              batch_size=args.batchSize,
                                              splits=[split],
                                              tokenizer=tok),
                                     Evaluation([split], featurized_scans,
                                                tok)))
                            for split in val_env_names))

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        if args.beam:
            beam_valid(train_env, tok, val_envs=val_envs)
        else:
            valid(train_env, tok, val_envs=val_envs)
    elif args.train == 'speaker':
        train_speaker(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validspeaker':
        valid_speaker(tok, val_envs)
    else:
        assert False
Example #5
def train_val():
    ''' Train on the training set, and validate on seen and unseen splits. '''
    # args.fast_train = True
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    if args.fast_train:
        feat_dict = read_img_features(features_fast)
    else:
        feat_dict = read_img_features(features)

    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    train_env = R2RBatch(feat_dict,
                         candidate_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    from collections import OrderedDict
    val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                              candidate_dict,
                                              batch_size=args.batchSize,
                                              splits=[split],
                                              tokenizer=tok),
                                     Evaluation([split], featurized_scans,
                                                tok)))
                            for split in ['val_seen', 'val_unseen', 'train']))

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        valid(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'speaker':
        train_speaker(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validspeaker':
        valid_speaker(tok, val_envs)
    else:
        assert False
Example #6
def train_val():
    ''' Train on the training set, and validate on seen and unseen splits. '''
    # args.fast_train = True
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features)

    featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])

    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok)
    from collections import OrderedDict

    val_env_names = ['val_unseen', 'val_seen']
    if args.submit:
        val_env_names.append('test')
    else:
        pass
        #val_env_names.append('train')

    if not args.beam:
        val_env_names.append("train")

    val_envs = OrderedDict(
        ((split,
          (R2RBatch(feat_dict, batch_size=args.batchSize, splits=[split], tokenizer=tok),
           Evaluation([split], featurized_scans, tok))
          )
         for split in val_env_names
         )
    )

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'vae_agent':
        train_vae_agent(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        if args.beam:
            beam_valid(train_env, tok, val_envs=val_envs)
        else:
            valid(train_env, tok, val_envs=val_envs)
    elif args.train == 'speaker':
        train_speaker(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validspeaker':
        valid_speaker(train_env, tok, val_envs)
    elif args.train == 'inferspeaker':
        unseen_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['tasks/R2R/data/aug_paths_test.json'], tokenizer=None)
        infer_speaker(unseen_env, tok)
    else:
        assert False
Example #7
def train_vae():
    """Train vae for sub-policy(z->policy)"""
    setup()
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)
    feat_dict = read_img_features(features)
    featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])
    # Create a batch training environment that will also preprocess text
    train_env = R2RBatch(feat_dict, batch_size=args.batchSize, splits=['sub_train'],tokenizer=tok)
    writer = SummaryWriter(logdir=log_dir)

    obs_dim = train_env.feature_size+args.angle_feat_size
    # TODO: latent_dim ablation
    path_len = 2 # fix path_len = 2, total path_len = 6
    vae = BaseVAE(train_env, tok, obs_dim, args.vae_latent_dim).cuda()
    vae.train()
Example #8
def train_val_augment(test_only=False):
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    tok_bert = get_tokenizer(args)

    # Load the env img features
    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set(
            [key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    # Load the augmentation data
    aug_path = args.aug
    # Create the training environment
    train_env = R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok_bert)
    aug_env = R2RBatch(feat_dict,
                       batch_size=args.batchSize,
                       splits=[aug_path],
                       tokenizer=tok_bert,
                       name='aug')

    # Setup the validation data
    val_envs = {
        split: (R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=[split],
                         tokenizer=tok_bert),
                Evaluation([split], featurized_scans, tok_bert))
        for split in val_env_names
    }

    # Start training
    train(train_env, tok_bert, args.iters, val_envs=val_envs, aug_env=aug_env)
Example #9
def train_val(test_only=False):
    ''' Train on the training set, and validate on seen and unseen splits. '''
    setup()
    tok = get_tokenizer(args)

    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set(
            [key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    train_env = R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    from collections import OrderedDict

    if args.submit:
        val_env_names.append('test')
    else:
        pass

    val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                              batch_size=args.batchSize,
                                              splits=[split],
                                              tokenizer=tok),
                                     Evaluation([split], featurized_scans,
                                                tok)))
                            for split in val_env_names))

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        valid(train_env, tok, val_envs=val_envs)
    else:
        assert False
Example #10
def create_augment_data():
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    # Load features
    feat_dict = read_img_features(features)
    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)

    # The datasets to be augmented
    print("Start to augment the data")
    aug_envs = []
    # aug_envs.append(
    #     R2RBatch(
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok
    #     )
    # )
    # aug_envs.append(
    #     SemiBatch(False, 'tasks/R2R/data/all_paths_46_removetrain.json',
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['train', 'val_seen'], tokenizer=tok)
    # )
    aug_envs.append(
        SemiBatch(False,
                  'tasks/R2R/data/all_paths_46_removevalunseen.json',
                  "unseen",
                  feat_dict,
                  candidate_dict,
                  batch_size=args.batchSize,
                  splits=['val_unseen'],
                  tokenizer=tok))
    aug_envs.append(
        SemiBatch(False,
                  'tasks/R2R/data/all_paths_46_removetest.json',
                  "test",
                  feat_dict,
                  candidate_dict,
                  batch_size=args.batchSize,
                  splits=['test'],
                  tokenizer=tok))
    # aug_envs.append(
    #     R2RBatch(
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['val_seen'], tokenizer=tok
    #     )
    # )
    # aug_envs.append(
    #     R2RBatch(
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['val_unseen'], tokenizer=tok
    #     )
    # )

    for snapshot in os.listdir(os.path.join(log_dir, 'state_dict')):
        # if snapshot != "best_val_unseen_bleu":  # Select a particular snapshot to process. (O/w, it will make for every snapshot)
        if snapshot != "best_val_unseen_bleu":
            continue

        # Create Speaker
        listner = Seq2SeqAgent(aug_envs[0], "", tok, args.maxAction)
        speaker = Speaker(aug_envs[0], listner, tok)

        # Load Weight
        load_iter = speaker.load(os.path.join(log_dir, 'state_dict', snapshot))
        print("Load from iter %d" % (load_iter))

        # Augment the env from aug_envs
        for aug_env in aug_envs:
            speaker.env = aug_env

            # Create the aug data
            import tqdm
            path2inst = speaker.get_insts(beam=args.beam, wrapper=tqdm.tqdm)
            data = []
            for datum in aug_env.fake_data:
                datum = datum.copy()
                path_id = datum['path_id']
                if path_id in path2inst:
                    datum['instructions'] = [
                        tok.decode_sentence(path2inst[path_id])
                    ]
                    datum.pop('instr_encoding')  # Remove Redundant keys
                    datum.pop('instr_id')
                    data.append(datum)

            print("Totally, %d data has been generated for snapshot %s." %
                  (len(data), snapshot))
            print("Average Length %0.4f" % utils.average_length(path2inst))
            print(datum)  # Print a Sample

            # Save the data
            import json
            os.makedirs(os.path.join(log_dir, 'aug_data'), exist_ok=True)
            beam_tag = "_beam" if args.beam else ""
            json.dump(data,
                      open(
                          os.path.join(
                              log_dir, 'aug_data', '%s_%s%s.json' %
                              (snapshot, aug_env.name, beam_tag)), 'w'),
                      sort_keys=True,
                      indent=4,
                      separators=(',', ': '))
Example #11
def train_val():
    ''' Train on the training set, and validate on seen and unseen splits. '''
    # args.fast_train = True
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features)

    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    if not args.test_obj:
        print('Loading compact pano-caffe object features ... (~3 seconds)')
        import pickle as pkl
        with open(
                '/egr/research-hlr/joslin/Matterdata/v1/scans/img_features/pano_object_class.pkl',
                'rb') as f_pc:
            pano_caffe = pkl.load(f_pc)
    else:
        pano_caffe = None

    train_env = R2RBatch(feat_dict,
                         pano_caffe,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    from collections import OrderedDict

    val_env_names = ['val_unseen', 'val_seen']
    if args.submit:
        val_env_names.append('test')
    else:
        pass
        # if you want to test "train", just uncomment this
        #val_env_names.append('train')

    if not args.beam:
        val_env_names.append("train")

    val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                              pano_caffe,
                                              batch_size=args.batchSize,
                                              splits=[split],
                                              tokenizer=tok),
                                     Evaluation([split], featurized_scans,
                                                tok)))
                            for split in val_env_names))

    # import sys
    # sys.exit()
    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        if args.beam:
            beam_valid(train_env, tok, val_envs=val_envs)
        else:
            valid(train_env, tok, val_envs=val_envs)
    elif args.train == 'speaker':
        train_speaker(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validspeaker':
        valid_speaker(tok, val_envs)
    else:
        assert False
Example #12
def meta_filter():
    """
    Train the listener with the augmented data
    """
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    if args.fast_train:
        feat_dict = read_img_features(features_fast)
    else:
        feat_dict = read_img_features(features)
    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    # Load the augmentation data
    if args.aug is None:  # If no aug path is given, fall back to the default speaker-generated data
        speaker_snap_name = "adam_drop6_correctsave"
        print("Loading from %s" % speaker_snap_name)
        aug_path = "snap/speaker/long/%s/aug_data/best_val_unseen_loss.json" % speaker_snap_name
    else:  # Load the path from args
        aug_path = args.aug

    # Create the training environment
    aug_env = R2RBatch(feat_dict,
                       candidate_dict,
                       batch_size=args.batchSize,
                       splits=[aug_path],
                       tokenizer=tok)
    train_env = R2RBatch(feat_dict,
                         candidate_dict,
                         batch_size=args.batchSize,
                         splits=['train@3333'],
                         tokenizer=tok)
    print("The augmented data_size is : %d" % train_env.size())
    stats = train_env.get_statistics()
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))

    # Setup the validation data
    val_envs = {
        split:
        (R2RBatch(feat_dict,
                  candidate_dict,
                  batch_size=args.batchSize,
                  splits=[split],
                  tokenizer=tok), Evaluation([split], featurized_scans, tok))
        for split in ['train', 'val_seen', 'val_unseen@133']
    }

    val_env, val_eval = val_envs['val_unseen@133']

    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    def filter_result():
        listner.env = val_env
        val_env.reset_epoch()
        listner.test(use_dropout=False, feedback='argmax')
        result = listner.get_results()
        score_summary, _ = val_eval.score(result)
        for metric, val in score_summary.items():
            if metric in ['success_rate']:
                return val

    listner.load(args.load)
    base_accu = (filter_result())
    print("BASE ACCU %0.4f" % base_accu)

    success = 0

    for data_id, datum in enumerate(aug_env.data):
        # Reload the param of the listener
        listner.load(args.load)
        train_env.reset_epoch(shuffle=True)

        listner.env = train_env

        # Train for the datum
        # iters = train_env.size() // train_env.batch_size
        iters = 10
        for i in range(iters):
            listner.env = train_env
            # train_env.reset(batch=([datum] * (train_env.batch_size // 2)), inject=True)
            train_env.reset(batch=[datum] * train_env.batch_size, inject=True)
            # train_env.reset()
            # train_env.reset()
            listner.train(1, feedback='sample', reset=False)
        # print("Iter %d, result %0.4f" % (i, filter_result()))
        now_accu = filter_result()
        if now_accu > base_accu:
            success += 1
        # print("RESULT %0.4f" % filter_result())
        print('Accu now %0.4f, success / total: %d / %d = %0.4f' %
              (now_accu, success, data_id + 1, success / (data_id + 1)))
Example #13
def arbiter():
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    if args.fast_train:
        feat_dict = read_img_features(features_fast)
    else:
        feat_dict = read_img_features(features)

    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)

    gt_train_env = R2RBatch(feat_dict,
                            candidate_dict,
                            batch_size=args.batchSize,
                            splits=['train'],
                            tokenizer=tok)
    gt_val_unseen_env = R2RBatch(feat_dict,
                                 candidate_dict,
                                 batch_size=args.batchSize,
                                 splits=['val_unseen'],
                                 tokenizer=tok)
    gt_val_seen_env = R2RBatch(feat_dict,
                               candidate_dict,
                               batch_size=args.batchSize,
                               splits=['val_seen'],
                               tokenizer=tok)

    # gt_train_env = R2RBatch(feat_dict, candidate_dict, batch_size=args.batchSize,
    #                         splits=['val_seen_half1'], tokenizer=tok)
    # gt_valid_env = R2RBatch(feat_dict, candidate_dict, batch_size=args.batchSize,
    #                         splits=['val_seen_half2'], tokenizer=tok)

    # Where to load the original data
    speaker_snap_name = "adam_drop6_correctsave" if args.speaker is None else args.speaker
    snapshot = "Iter_060000"
    aug_data_name = snapshot + "_FAKE"
    aug_data_path = "snap/speaker/long/%s/aug_data/%s.json" % (
        speaker_snap_name, aug_data_name)

    # Where to save the splitted data
    saved_path = os.path.join(log_dir, 'aug_data')
    os.makedirs(saved_path, exist_ok=True)
    gen_train_path = os.path.join(saved_path,
                                  "%s_%s.json" % (aug_data_name, 'train'))
    gen_valid_path = os.path.join(saved_path,
                                  "%s_%s.json" % (aug_data_name, 'valid'))
    gen_test_path = os.path.join(saved_path,
                                 "%s_%s.json" % (aug_data_name, 'test'))

    if args.train == 'arbiter':
        # Load the augmented data
        print("\nLoading the augmentation data from path %s" % aug_data_path)
        aug_data = json.load(open(aug_data_path))
        print("The size of the augmentation data is %d" % len(aug_data))

        # Shuffle and split the data.
        print("Creating the json files ...")
        random.seed(1)
        random.shuffle(aug_data)
        train_size = gt_train_env.size() * 1  # The size of training data should be much larger
        valid_size = gt_val_seen_env.size()  # The size of the test data
        print("valid size is %d " % valid_size)
        gen_train_data = aug_data[:train_size]
        gen_valid_data = aug_data[train_size:(train_size + valid_size)]
        gen_test_data = aug_data[train_size + valid_size:]

        # Create the json files
        json.dump(gen_train_data,
                  open(gen_train_path, 'w'),
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
        json.dump(gen_valid_data,
                  open(gen_valid_path, 'w'),
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
        json.dump(gen_test_data,
                  open(gen_test_path, 'w'),
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
        print("Finish dumping the json files\n")

        # Load augmentation Envs
        gen_train_path = "snap/speaker/long/%s/aug_data/%s.json" % (  # Train: unseen generate vs unseen gt
            speaker_snap_name, snapshot + "_val_unseen")
        aug_val_unseen_env = R2RBatch(feat_dict,
                                      candidate_dict,
                                      batch_size=args.batchSize,
                                      splits=[gen_train_path],
                                      tokenizer=tok)

        aug_fake_env = R2RBatch(feat_dict,
                                candidate_dict,
                                batch_size=args.batchSize,
                                splits=[gen_valid_path],
                                tokenizer=tok)

        gen_valid_path = "snap/speaker/long/%s/aug_data/%s.json" % (  # Valid:   seen generate vs   seen gt
            speaker_snap_name, snapshot + "_val_seen")
        aug_val_seen_env = R2RBatch(feat_dict,
                                    candidate_dict,
                                    batch_size=args.batchSize,
                                    splits=[gen_valid_path],
                                    tokenizer=tok)

        # gen_valid_path = "snap/speaker/long/%s/aug_data/%s.json" % (        # Valid:   seen generate vs   seen gt
        #     speaker_snap_name,
        #     snapshot + "_train"
        # )
        # aug_train_env = R2RBatch(feat_dict, candidate_dict, batch_size=args.batchSize, splits=[gen_valid_path], tokenizer=tok)
        # print("Loading the generated data from %s with size %d" % (aug_data_path, aug_train_env.size()))

        # Create Arbiter Envs
        arbiter_train_env = ArbiterBatch(gt_val_unseen_env,
                                         aug_val_unseen_env,
                                         args.batchSize // 2,
                                         args.batchSize // 2,
                                         feat_dict,
                                         candidate_dict,
                                         batch_size=args.batchSize,
                                         splits=[],
                                         tokenizer=tok)
        print("The size of Training data in Arbiter is %d" %
              arbiter_train_env.size())
        arbiter_valid_env = ArbiterBatch(gt_val_seen_env,
                                         aug_val_seen_env,
                                         args.batchSize // 2,
                                         args.batchSize // 2,
                                         feat_dict,
                                         candidate_dict,
                                         batch_size=args.batchSize,
                                         splits=[],
                                         tokenizer=tok)

        arbiter_valid_fake_env = ArbiterBatch(gt_val_seen_env,
                                              aug_fake_env,
                                              args.batchSize // 2,
                                              args.batchSize // 2,
                                              feat_dict,
                                              candidate_dict,
                                              batch_size=args.batchSize,
                                              splits=[],
                                              tokenizer=tok)
        # arbiter_valid_train_env = ArbiterBatch(gt_val_seen_env, aug_train_env, args.batchSize//2, args.batchSize//2,
        #                                        feat_dict, candidate_dict, batch_size=args.batchSize, splits=[], tokenizer=tok)

        print("The size of Validation data in Arbiter is %d" %
              arbiter_valid_env.size())
        train_arbiter(
            arbiter_train_env,
            tok,
            args.iters,
            val_envs={
                'train': arbiter_train_env,
                'val_seen': arbiter_valid_env,
                'valid_fake': arbiter_valid_fake_env,
                # 'valid_train': arbiter_valid_train_env,
            })
    if args.train == 'filterarbiter':
        # Load the augmentation test env
        # aug_test_env = R2RBatch(feat_dict, candidate_dict, batch_size=args.batchSize, splits=[gen_test_path], tokenizer=tok)
        aug_test_env = R2RBatch(feat_dict,
                                candidate_dict,
                                batch_size=args.batchSize,
                                splits=[aug_data_path],
                                tokenizer=tok)
        print("%d data is loaded to be filtered" % (aug_test_env.size()))
        filter_data = filter_arbiter(gt_val_seen_env, aug_test_env, tok)
        print("The size of the remaining data is %d" % len(filter_data))
        json.dump(filter_data,
                  open(
                      os.path.join(log_dir, "aug_data/%s_filter.json" %
                                   (aug_data_name)), 'w'),
                  sort_keys=True,
                  indent=4,
                  separators=(',', ': '))
Example #14
def train_val_augment(test_only=False):
    """
    Train the listener with the augmented data
    """
    setup()
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features, test_only=test_only)

    if test_only:
        featurized_scans = None
        val_env_names = ['val_train_seen']
    else:
        featurized_scans = set(
            [key.split("_")[0] for key in list(feat_dict.keys())])
        val_env_names = ['val_train_seen', 'val_seen', 'val_unseen']

    if not args.test_obj:
        print('Loading compact pano-caffe object features ... (~3 seconds)')
        import pickle as pkl
        with open('img_features/objects/pano_object_class.pkl', 'rb') as f_pc:
            pano_caffe = pkl.load(f_pc)
    else:
        pano_caffe = None

    aug_path = args.aug

    # Create the training environment
    train_env = R2RBatch(feat_dict,
                         pano_caffe,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    aug_env = R2RBatch(feat_dict,
                       pano_caffe,
                       batch_size=args.batchSize,
                       splits=[aug_path],
                       tokenizer=tok,
                       name='aug')

    stats = train_env.get_statistics()
    print("The training data_size is : %d" % train_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))
    stats = aug_env.get_statistics()
    print("The augmentation data size is %d" % aug_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))

    val_envs = {
        split:
        (R2RBatch(feat_dict,
                  pano_caffe,
                  batch_size=args.batchSize,
                  splits=[split],
                  tokenizer=tok), Evaluation([split], featurized_scans, tok))
        for split in val_env_names
    }

    train(train_env, tok, args.iters, val_envs=val_envs, aug_env=aug_env)
Example #15
def test():
    print('current directory', os.getcwd())
    os.chdir('..')
    print('current directory', os.getcwd())

    visible_gpu = "0"
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu

    args.name = 'SSM'
    args.attn = 'soft'
    args.train = 'listener'
    args.featdropout = 0.3
    args.angle_feat_size = 128
    args.feedback = 'sample'
    args.ml_weight = 0.2
    args.sub_out = 'max'
    args.dropout = 0.5
    args.optim = 'adam'
    args.lr = 3e-4
    args.iters = 80000
    args.maxAction = 35
    args.batchSize = 24
    args.target_batch_size = 24

    args.self_train = True
    args.aug = 'tasks/R2R/data/aug_paths.json'

    args.speaker = 'snap/speaker/state_dict/best_val_unseen_bleu'

    args.featdropout = 0.4
    args.iters = 200000

    if args.optim == 'rms':
        print("Optimizer: Using RMSProp")
        args.optimizer = torch.optim.RMSprop
    elif args.optim == 'adam':
        print("Optimizer: Using Adam")
        args.optimizer = torch.optim.Adam
    elif args.optim == 'sgd':
        print("Optimizer: sgd")
        args.optimizer = torch.optim.SGD

    log_dir = 'snap/%s' % args.name
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    logdir = '%s/eval' % log_dir
    writer = SummaryWriter(logdir=logdir)

    TRAIN_VOCAB = 'tasks/R2R/data/train_vocab.txt'
    TRAINVAL_VOCAB = 'tasks/R2R/data/trainval_vocab.txt'

    IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'

    if args.features == 'imagenet':
        features = IMAGENET_FEATURES

    if args.fast_train:
        name, ext = os.path.splitext(features)
        features = name + "-fast" + ext

    print(args)

    def setup():
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        # Check for vocabs
        if not os.path.exists(TRAIN_VOCAB):
            write_vocab(build_vocab(splits=['train']), TRAIN_VOCAB)
        if not os.path.exists(TRAINVAL_VOCAB):
            write_vocab(
                build_vocab(splits=['train', 'val_seen', 'val_unseen']),
                TRAINVAL_VOCAB)

    #
    setup()

    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features)

    print('start extract keys...')
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])
    print('keys extracted...')

    val_envs = {
        split: R2RBatch(feat_dict,
                        batch_size=args.batchSize,
                        splits=[split],
                        tokenizer=tok)
        for split in ['train', 'val_seen', 'val_unseen']
    }

    evaluators = {
        split: Evaluation([split], featurized_scans, tok)
        for split in ['train', 'val_seen', 'val_unseen']
    }

    learner = Learner(val_envs,
                      "",
                      tok,
                      args.maxAction,
                      process_num=2,
                      visible_gpu=visible_gpu)
    learner.eval_init()

    for i in range(0, 10000):
        ckpt = '%s/state_dict/Iter_%06d' % (log_dir, (i + 1) * 100)
        while not os.path.exists(ckpt):
            time.sleep(10)

        time.sleep(10)

        learner.load_eval(ckpt)

        results = learner.eval()
        loss_str = ''
        for key in results:
            evaluator = evaluators[key]
            result = results[key]

            score_summary, _ = evaluator.score(result)

            loss_str += ", %s \n" % key

            for metric, val in score_summary.items():
                loss_str += ', %s: %.3f' % (metric, val)
                writer.add_scalar('%s/%s' % (metric, key), val, (i + 1) * 100)

            loss_str += '\n'

        print(loss_str)
Example #16
def train_val_augment():
    """
    Train the listener with the augmented data
    """
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(train_vocab)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    # Load the env img features
    feat_dict = read_img_features(features)
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    # Load the augmentation data
    if args.upload:
        aug_path = get_sync_dir(os.path.join(args.upload_path, args.aug))
    else:
        aug_path = os.path.join(args.R2R_Aux_path, args.aug)

    # Create the training environment

    # load object feature
    obj_s_feat = None
    if args.sparseObj:
        obj_s_feat = utils.read_obj_sparse_features(sparse_obj_feat,
                                                    args.objthr)

    obj_d_feat = None
    if args.denseObj:
        obj_d_feat = utils.read_obj_dense_features(dense_obj_feat1,
                                                   dense_obj_feat2, bbox,
                                                   sparse_obj_feat,
                                                   args.objthr)

    train_env = R2RBatch(feat_dict,
                         obj_d_feat=obj_d_feat,
                         obj_s_feat=obj_s_feat,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    aug_env = R2RBatch(feat_dict,
                       obj_d_feat=obj_d_feat,
                       obj_s_feat=obj_s_feat,
                       batch_size=args.batchSize,
                       splits=[aug_path],
                       tokenizer=tok,
                       name='aug')

    # Printing out the statistics of the dataset
    stats = train_env.get_statistics()
    print("The training data_size is : %d" % train_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))
    stats = aug_env.get_statistics()
    print("The augmentation data size is %d" % aug_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))

    # Setup the validation data
    val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                              obj_d_feat=obj_d_feat,
                                              obj_s_feat=obj_s_feat,
                                              batch_size=args.batchSize,
                                              splits=[split],
                                              tokenizer=tok),
                                     Evaluation([split], featurized_scans,
                                                tok)))
                            for split in ['train', 'val_seen', 'val_unseen']))

    # Start training
    train(train_env, tok, args.iters, val_envs=val_envs, aug_env=aug_env)
Example #17
def train_val():
    ''' Train on the training set, and validate on seen and unseen splits. '''
    # args.fast_train = True
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(train_vocab)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features)

    # load object feature
    obj_s_feat = None
    if args.sparseObj:
        obj_s_feat = utils.read_obj_sparse_features(sparse_obj_feat,
                                                    args.objthr)

    obj_d_feat = None
    if args.denseObj:
        obj_d_feat = utils.read_obj_dense_features(dense_obj_feat1,
                                                   dense_obj_feat2, bbox,
                                                   sparse_obj_feat,
                                                   args.objthr)

    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    train_env = R2RBatch(feat_dict,
                         obj_d_feat=obj_d_feat,
                         obj_s_feat=obj_s_feat,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)

    val_env_names = ['val_unseen', 'val_seen']
    if args.submit:
        val_env_names.append('test')
    else:
        pass
        #val_env_names.append('train')

    if not args.beam:
        val_env_names.append("train")

    val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                              obj_d_feat=obj_d_feat,
                                              obj_s_feat=obj_s_feat,
                                              batch_size=args.batchSize,
                                              splits=[split],
                                              tokenizer=tok),
                                     Evaluation([split], featurized_scans,
                                                tok)))
                            for split in val_env_names))

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        if args.beam:
            beam_valid(train_env, tok, val_envs=val_envs)
        else:
            valid(train_env, tok, val_envs=val_envs)
    elif args.train == 'speaker':
        train_speaker(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validspeaker':
        valid_speaker(tok, val_envs)
    else:
        assert False
Example #18
from train import setup
from collections import defaultdict, OrderedDict

import numpy as np
from matplotlib import pyplot as plt

# The names below are used later in this snippet without explicit imports;
# the module paths are assumed from the surrounding R2R codebase.
from param import args
from utils import read_vocab, Tokenizer, read_img_features
from env import R2RBatch
from eval import Evaluation
from agent import Seq2SeqAgent
import speaker

setup()
args.angle_feat_size = 128
TRAIN_VOCAB = 'tasks/R2R/data/train_vocab.txt'
TRAINVAL_VOCAB = 'tasks/R2R/data/trainval_vocab.txt'

IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'img_features/ResNet-152-places365.tsv'
features = IMAGENET_FEATURES
vocab = read_vocab(TRAIN_VOCAB)
tok = Tokenizer(vocab=vocab, encoding_length=80)
feat_dict = read_img_features(features)

train_env = R2RBatch(feat_dict, batch_size=64, splits=['train'], tokenizer=tok)
log_dir = "snap/speaker/state_dict/best_val_seen_bleu"
val_env_names = ['val_unseen', 'val_seen']
featurized_scans = set([key.split("_")[0] for key in list(feat_dict.keys())])

val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                          batch_size=args.batchSize,
                                          splits=[split],
                                          tokenizer=tok),
                                 Evaluation([split], featurized_scans, tok)))
                        for split in val_env_names))

listner = Seq2SeqAgent(train_env, "", tok, 35)
speaker = speaker.Speaker(train_env, listner, tok)
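Example #18 stops right after constructing the speaker. A plausible continuation, mirroring the snapshot-loading and instruction-decoding pattern of Example #10 (the load path, the beam flag, and the utils.average_length helper are assumptions borrowed from that example, not part of this snippet):

import tqdm
import utils  # assumed to provide average_length, as in Example #10

# Load the speaker snapshot that log_dir points at and report its training iteration
load_iter = speaker.load(log_dir)
print("Loaded speaker weights from iter %d" % load_iter)

# Generate instructions for each validation environment and report their average length
for split, (env, _) in val_envs.items():
    speaker.env = env
    path2inst = speaker.get_insts(beam=False, wrapper=tqdm.tqdm)
    print("%s: %d paths, average instruction length %0.4f"
          % (split, len(path2inst), utils.average_length(path2inst)))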
Example #19
def train_val():
    ''' Train on the training set, and validate on seen and unseen splits. '''
    # args.fast_train = True
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)

    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features)

    # load object feature
    obj_s_feat = None
    if args.sparseObj:
        print("Start loading the object sparse feature")
        start = time.time()
        obj_s_feat = np.load(sparse_obj_feat, allow_pickle=True).item()
        print(
            "Finish Loading the object sparse feature from %s in %0.4f seconds"
            % (sparse_obj_feat, time.time() - start))

    obj_d_feat = None
    if args.denseObj:
        print("Start loading the object dense feature")
        start = time.time()
        obj_d_feat1 = np.load(dense_obj_feat1, allow_pickle=True).item()
        obj_d_feat2 = np.load(dense_obj_feat2, allow_pickle=True).item()
        obj_d_feat = {**obj_d_feat1, **obj_d_feat2}
        print(
            "Finish loading the dense object feature from %s and %s in %0.4f seconds"
            % (dense_obj_feat1, dense_obj_feat2, time.time() - start))

    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    train_env = R2RBatch(feat_dict,
                         obj_d_feat=obj_d_feat,
                         obj_s_feat=obj_s_feat,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    from collections import OrderedDict

    val_env_names = ['val_unseen', 'val_seen']
    if args.submit:
        val_env_names.append('test')
    else:
        pass
        #val_env_names.append('train')

    if not args.beam:
        val_env_names.append("train")

    val_envs = OrderedDict(((split, (R2RBatch(feat_dict,
                                              obj_d_feat=obj_d_feat,
                                              obj_s_feat=obj_s_feat,
                                              batch_size=args.batchSize,
                                              splits=[split],
                                              tokenizer=tok),
                                     Evaluation([split], featurized_scans,
                                                tok)))
                            for split in val_env_names))

    if args.train == 'listener':
        train(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validlistener':
        if args.beam:
            beam_valid(train_env, tok, val_envs=val_envs)
        else:
            valid(train_env, tok, val_envs=val_envs)
    elif args.train == 'speaker':
        train_speaker(train_env, tok, args.iters, val_envs=val_envs)
    elif args.train == 'validspeaker':
        valid_speaker(tok, val_envs)
    else:
        assert False
Example #20
def train():
    print('current directory', os.getcwd())
    os.chdir('..')
    print('current directory', os.getcwd())

    visible_gpu = "0,1,2,3"  # avaiable GPUs, GPU0 is for processing gradient accumulating
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpu

    args.name = 'SSM'
    args.attn = 'soft'
    args.train = 'listener'
    args.featdropout = 0.4
    args.angle_feat_size = 128
    args.feedback = 'sample'
    args.ml_weight = 0.2
    args.sub_out = 'max'
    args.dropout = 0.5
    args.optim = 'rms'
    args.lr = 1e-4
    args.iters = 80000
    args.maxAction = 15
    args.batchSize = 16
    args.aug = 'tasks/R2R/data/aug_paths.json'
    args.self_train = True

    args.featdropout = 0.4
    args.iters = 200000

    if args.optim == 'rms':
        print("Optimizer: Using RMSProp")
        args.optimizer = torch.optim.RMSprop
    elif args.optim == 'adam':
        print("Optimizer: Using Adam")
        args.optimizer = torch.optim.Adam
    elif args.optim == 'sgd':
        print("Optimizer: sgd")
        args.optimizer = torch.optim.SGD

    log_dir = 'snap/%s' % args.name
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    TRAIN_VOCAB = 'tasks/R2R/data/train_vocab.txt'
    TRAINVAL_VOCAB = 'tasks/R2R/data/trainval_vocab.txt'

    IMAGENET_FEATURES = 'img_features/ResNet-152-imagenet.tsv'

    if args.features == 'imagenet':
        features = IMAGENET_FEATURES

    if args.fast_train:
        name, ext = os.path.splitext(features)
        features = name + "-fast" + ext

    print(args)

    def setup():
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)
        # Check for vocabs
        if not os.path.exists(TRAIN_VOCAB):
            write_vocab(build_vocab(splits=['train']), TRAIN_VOCAB)
        if not os.path.exists(TRAINVAL_VOCAB):
            write_vocab(
                build_vocab(splits=['train', 'val_seen', 'val_unseen']),
                TRAINVAL_VOCAB)

    #
    setup()

    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    feat_dict = read_img_features(features)

    # Create the training environment
    train_env = R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    aug_env = R2RBatch(feat_dict,
                       batch_size=args.batchSize,
                       splits=[args.aug],
                       tokenizer=tok)

    train_env = {'train': train_env, 'aug': aug_env}

    load_path = None

    torch.autograd.set_detect_anomaly(True)

    learner = Learner(train_env,
                      "",
                      tok,
                      args.maxAction,
                      process_num=4,
                      max_node=17,
                      visible_gpu=visible_gpu)

    if load_path is not None:
        print('load checkpoint from:', load_path)
        learner.load(load_path)

    learner.train()
Example #21
def train_val_augment():
    """
    Train the listener with the augmented data
    """
    setup()
    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    if args.fast_train:
        feat_dict = read_img_features(features_fast)
    else:
        feat_dict = read_img_features(features)
    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    # Load the augmentation data
    if args.aug is None:  # If no aug path is given, fall back to the default speaker-generated data
        speaker_snap_name = "adam_drop6_correctsave"
        print("Loading from %s" % speaker_snap_name)
        aug_path = "snap/speaker/long/%s/aug_data/best_val_unseen_loss.json" % speaker_snap_name
    else:  # Load the path from args
        aug_path = args.aug

    # The dataset used in training
    splits = [aug_path, 'train'] if args.combineAug else [aug_path]

    # Create the training environment
    if args.half_half:
        assert args.aug is not None
        gt_env = R2RBatch(feat_dict,
                          candidate_dict,
                          batch_size=args.batchSize,
                          splits=['train'],
                          tokenizer=tok)
        aug_env = R2RBatch(feat_dict,
                           candidate_dict,
                           batch_size=args.batchSize,
                           splits=[aug_path],
                           tokenizer=tok)
        train_env = ArbiterBatch(gt_env,
                                 aug_env,
                                 args.batchSize // 2,
                                 args.batchSize // 2,
                                 feat_dict,
                                 candidate_dict,
                                 batch_size=args.batchSize,
                                 splits=[],
                                 tokenizer=tok)
    else:
        train_env = R2RBatch(feat_dict,
                             candidate_dict,
                             batch_size=args.batchSize,
                             splits=splits,
                             tokenizer=tok)

    print("The augmented data_size is : %d" % train_env.size())
    # stats = train_env.get_statistics()
    # print("The average instruction length of the dataset is %0.4f." % (stats['length']))
    # print("The average action length of the dataset is %0.4f." % (stats['path']))

    # Setup the validation data
    val_envs = {
        split:
        (R2RBatch(feat_dict,
                  candidate_dict,
                  batch_size=args.batchSize,
                  splits=[split],
                  tokenizer=tok), Evaluation([split], featurized_scans, tok))
        for split in ['train', 'val_seen', 'val_unseen']
    }

    # Start training
    train(train_env, tok, args.iters, val_envs=val_envs)
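Across these examples the entry point is usually chosen by a command-line flag; a hypothetical dispatcher in the style of the if/elif chains above (the flag values and the __main__ guard are assumptions, not part of the snippets):

if __name__ == "__main__":
    if args.train in ('listener', 'validlistener', 'speaker', 'validspeaker'):
        train_val()
    elif args.train == 'auglistener':
        train_val_augment()
    else:
        assert False, "unknown args.train mode: %s" % args.train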