コード例 #1
0
def train_val_augment():
    """Train the listener with augmented (back-translation) data.

    Builds a tokenizer from the training vocabulary, loads pre-computed
    image features, constructs the 'train' environment plus an auxiliary
    augmentation environment from ``args.aug``, prints dataset statistics,
    sets up validation environments, and then delegates to ``train``.

    Relies on module-level names: ``setup``, ``read_vocab``, ``train_vocab``,
    ``Tokenizer``, ``read_img_features``, ``features``, ``R2RBatch``,
    ``Evaluation``, ``get_sync_dir``, ``args``, and ``train``.
    """
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(train_vocab)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    # Load the env img features; scan ids are the prefix before '_' in keys.
    feat_dict = read_img_features(features)
    featurized_scans = set(
        [key.split("_")[0] for key in list(feat_dict.keys())])

    # Resolve the path to the augmentation split.
    if args.upload:
        aug_path = get_sync_dir(os.path.join(args.upload_path, args.aug))
    else:
        # BUG FIX: this branch previously assigned `aux_path` (typo), leaving
        # `aug_path` undefined and raising NameError below when not uploading.
        aug_path = os.path.join(args.R2R_Aux_path, args.aug)

    # Create the training environment
    train_env = R2RBatch(feat_dict,
                         batch_size=args.batchSize,
                         splits=['train'],
                         tokenizer=tok)
    aug_env = R2RBatch(feat_dict,
                       batch_size=args.batchSize,
                       splits=[aug_path],
                       tokenizer=tok,
                       name='aug')

    # Printing out the statistics of the dataset
    stats = train_env.get_statistics()
    print("The training data_size is : %d" % train_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))
    stats = aug_env.get_statistics()
    print("The augmentation data size is %d" % aug_env.size())
    print("The average instruction length of the dataset is %0.4f." %
          (stats['length']))
    print("The average action length of the dataset is %0.4f." %
          (stats['path']))

    # Setup the validation data: one (env, evaluator) pair per split.
    val_envs = {
        split:
        (R2RBatch(feat_dict,
                  batch_size=args.batchSize,
                  splits=[split],
                  tokenizer=tok), Evaluation([split], featurized_scans, tok))
        for split in ['train', 'val_seen', 'val_unseen']
    }

    # Start training
    train(train_env, tok, args.iters, val_envs=val_envs, aug_env=aug_env)
コード例 #2
0
ファイル: train.py プロジェクト: ZhuFengdaaa/MG-AuxRN
def train(train_env, tok, n_iters, log_every=100, val_envs={}, aug_env=None):
    """Train the listener agent, logging to tensorboard and validating.

    Args:
        train_env: R2RBatch environment over the ground-truth 'train' split.
        tok: shared Tokenizer instance.
        n_iters: total number of training iterations to run.
        log_every: iterations per log/validation round (forced to 40 when
            ``args.fast_train`` is set).
        val_envs: mapping env_name -> (env, evaluator) used for validation.
            NOTE(review): mutable default ``{}`` — harmless as long as callers
            never mutate it; left unchanged here.
        aug_env: optional augmentation environment; when given, training
            alternates GT batches and back-translation batches.
    """
    writer = SummaryWriter(logdir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    # Optionally build a speaker for back-translation self-training and
    # restore its weights from a checkpoint.
    speaker = None
    if args.self_train:
        speaker = Speaker(train_env, listner, tok)
        if args.speaker is not None:
            if args.upload:
                print("Load the speaker from %s." % args.speaker)
                speaker.load(
                    get_sync_dir(os.path.join(args.upload_path, args.speaker)))
            else:
                print("Load the speaker from %s." % args.speaker)
                speaker.load(os.path.join(args.R2R_Aux_path, args.speaker))

    # Optionally resume the listener; listner.load returns the iteration to
    # resume from.
    start_iter = 0
    if args.load is not None:
        if args.upload:
            refs_paths = get_outputs_refs_paths()['experiments'][0]
            print(refs_paths)
            load_model = os.path.join(refs_paths, args.load)
            print(load_model)
            print("LOAD THE listener from %s" % load_model)
            start_iter = listner.load(load_model)
        else:
            print("LOAD THE listener from %s" % args.load)
            start_iter = listner.load(
                os.path.join(args.R2R_Aux_path, args.load))

    start = time.time()

    # Track the best success rate seen per validation split; 'state' holds a
    # printable summary of the iteration that achieved it.
    best_val = {
        'val_seen': {
            "accu": 0.,
            "state": "",
            'update': False
        },
        'val_unseen': {
            "accu": 0.,
            "state": "",
            'update': False
        }
    }
    if args.fast_train:
        log_every = 40
    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        # Last round may be shorter than log_every.
        interval = min(log_every, start_iter + n_iters - idx)
        iter = idx + interval  # NOTE: shadows the builtin `iter`

        # Train for log_every interval
        if aug_env is None:  # The default training process
            listner.env = train_env
            listner.train(interval,
                          feedback=feedback_method)  # Train interval iters
        else:
            # With augmentation: alternate one GT batch and one
            # back-translation batch per step (hence interval // 2).
            # args.ml_weight is mutated globally to weight the ML loss
            # differently for each data source.
            if args.accumulate_grad:
                for _ in range(interval // 2):
                    listner.zero_grad()
                    listner.env = train_env

                    # Train with GT data
                    args.ml_weight = 0.2
                    listner.accumulate_gradient(feedback_method)
                    listner.env = aug_env

                    # Train with Back Translation
                    args.ml_weight = 0.6  # Sem-Configuration
                    listner.accumulate_gradient(feedback_method,
                                                speaker=speaker)
                    listner.optim_step()
            else:
                for _ in range(interval // 2):
                    # Train with GT data
                    listner.env = train_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    # Train with Back Translation
                    listner.env = aug_env
                    args.ml_weight = 0.6
                    listner.train(1, feedback=feedback_method, speaker=speaker)

        # Log the training stats to tensorboard. Each logs[...] entry is a
        # per-step list; averages are taken over log_length entries (max'd
        # with 1 to avoid division by zero when a log is empty).
        total = max(sum(listner.logs['total']), 1)
        # import pdb; pdb.set_trace() # length_rl == length_ml ? entropy length
        assert (max(len(listner.logs['rl_loss']),
                    1) == max(len(listner.logs['ml_loss']), 1))
        max_rl_length = max(len(listner.logs['critic_loss']), 1)
        log_length = max(len(listner.logs['rl_loss']), 1)
        rl_loss = sum(listner.logs['rl_loss']) / log_length
        ml_loss = sum(listner.logs['ml_loss']) / log_length
        critic_loss = sum(listner.logs['critic_loss']
                          ) / log_length  #/ length / args.batchSize
        spe_loss = sum(listner.logs['spe_loss']) / log_length
        pro_loss = sum(listner.logs['pro_loss']) / log_length
        mat_loss = sum(listner.logs['mat_loss']) / log_length
        fea_loss = sum(listner.logs['fea_loss']) / log_length
        ang_loss = sum(listner.logs['ang_loss']) / log_length
        entropy = sum(
            listner.logs['entropy']) / log_length  #/ length / args.batchSize
        predict_loss = sum(listner.logs['us_loss']) / log_length
        writer.add_scalar("loss/rl_loss", rl_loss, idx)
        writer.add_scalar("loss/ml_loss", ml_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/spe_loss", spe_loss, idx)
        writer.add_scalar("loss/pro_loss", pro_loss, idx)
        writer.add_scalar("loss/mat_loss", mat_loss, idx)
        writer.add_scalar("loss/fea_loss", fea_loss, idx)
        writer.add_scalar("loss/ang_loss", ang_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_rl_length", max_rl_length, idx)
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("loss/unsupervised", predict_loss, idx)
        print("total_actions", total)
        print("max_rl_length", max_rl_length)

        # Run validation
        loss_str = ""
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation loss under the same conditions as training;
            # cap the 'train' split at 20 batches unless fast_train.
            iters = None if args.fast_train or env_name != 'train' else 20  # 20 * 64 = 1280

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=iters)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += "%s " % env_name
            for metric, val in score_summary.items():
                if metric in ['success_rate']:
                    loss_str += ', %s: %.4f' % (metric, val)
                    writer.add_scalar("%s/accuracy" % env_name, val, idx)
                    # Mark splits whose success rate improved; snapshot below.
                    if env_name in best_val:
                        if val > best_val[env_name]['accu']:
                            best_val[env_name]['accu'] = val
                            best_val[env_name]['update'] = True
                if metric in ['spl']:
                    writer.add_scalar("%s/spl" % env_name, val, idx)
                    loss_str += ', %s: %.4f' % (metric, val)
            loss_str += '\n'
        loss_str += '\n'

        # Save a "best" checkpoint for each split that just improved.
        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d \n%s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                file_dir = os.path.join(output_dir, "snap", args.name,
                                        "state_dict", "best_%s" % (env_name))
                listner.save(idx, file_dir)
        print(('%s (%d %d%%) \n%s' % (timeSince(start,
                                                float(iter) / n_iters), iter,
                                      float(iter) / n_iters * 100, loss_str)))

        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])

        # Periodic checkpoint regardless of validation performance.
        if iter % args.save_iter == 0:
            file_dir = os.path.join(output_dir, "snap", args.name,
                                    "state_dict", "Iter_%06d" % (iter))
            listner.save(idx, file_dir)
コード例 #3
0
ファイル: train.py プロジェクト: ZhuFengdaaa/MG-AuxRN
from polyaxon_client.tracking import get_outputs_refs_paths
# When validating an uploaded listener, resolve the checkpoint path from the
# first referenced polyaxon experiment's outputs.
# NOTE(review): assumes args.load is set whenever args.train == 'validlistener'
# and args.upload — confirm against the argument parser.
if args.train == 'validlistener' and args.upload:
    refs_paths = get_outputs_refs_paths()['experiments'][0]
    print(refs_paths)
    load_model = os.path.join(refs_paths, args.load)
    print(load_model)

import warnings
warnings.filterwarnings("ignore")

from tensorboardX import SummaryWriter
from polyaxon_client.tracking import get_outputs_path

# In upload (polyaxon) mode, sync every data artifact from the remote upload
# path into a local directory and anchor outputs/logs under the experiment's
# outputs path.
if args.upload:
    train_vocab = get_sync_dir(os.path.join(args.upload_path,
                                            args.TRAIN_VOCAB))
    trainval_vocab = get_sync_dir(
        os.path.join(args.upload_path, args.TRAINVAL_VOCAB))
    features = get_sync_dir(
        os.path.join(args.upload_path, args.IMAGENET_FEATURES))
    output_dir = get_outputs_path()
    log_dir = os.path.join(output_dir, "snap", args.name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # Object-level features and bounding boxes for the auxiliary tasks.
    sparse_obj_feat = get_sync_dir(
        os.path.join(args.upload_path, args.SPARSE_OBJ_FEATURES))
    dense_obj_feat1 = get_sync_dir(
        os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES1))
    dense_obj_feat2 = get_sync_dir(
        os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES2))
    bbox = get_sync_dir(os.path.join(args.upload_path, args.BBOX_FEATURES))
コード例 #4
0
ファイル: train.py プロジェクト: yestinl/R2R-Aux-Obj
# Vocabulary files for the R2R task.
TRAIN_VOCAB = 'tasks/R2R/data/train_vocab.txt'
TRAINVAL_VOCAB = 'tasks/R2R/data/trainval_vocab.txt'

# Pre-extracted panorama image features (TSV, one row per view).
IMAGENET_FEATURES = 'lyx/img_features/ResNet-152-imagenet.tsv'
PLACE365_FEATURES = 'lyx/img_features/ResNet-152-places365.tsv'

# Object-feature archives, parameterized by the detector output directory.
SPARSE_OBJ_FEATURES = 'lyx/obj_features/%s/panorama_objs_Features_nms_%s.npy' % (
    args.objdir, args.objdir)
DENSE_OBJ_FEATURES1 = 'lyx/obj_features/%s/panorama_objs_DenseFeatures_nms1_%s.npy' % (
    args.objdir, args.objdir)
DENSE_OBJ_FEATURES2 = 'lyx/obj_features/%s/panorama_objs_DenseFeatures_nms2_%s.npy' % (
    args.objdir, args.objdir)

# Sync feature files locally when using the imagenet backbone.
# NOTE(review): no else-branch — `features` et al. stay undefined for other
# values of args.features; confirm callers guard on this.
if args.features == 'imagenet':
    features = get_sync_dir(IMAGENET_FEATURES)
    sparse_obj_feat = get_sync_dir(SPARSE_OBJ_FEATURES)
    dense_obj_feat1 = get_sync_dir(DENSE_OBJ_FEATURES1)
    dense_obj_feat2 = get_sync_dir(DENSE_OBJ_FEATURES2)

# fast_train swaps in a reduced "-fast" feature file with the same extension.
if args.fast_train:
    name, ext = os.path.splitext(features)
    features = name + "-fast" + ext

feedback_method = args.feedback  # teacher or sample

print(args)
def train_speaker(train_env, tok, n_iters, log_every=500, val_envs={}):
    writer = SummaryWriter(logdir=log_dir)
コード例 #5
0
from polyaxon_client.tracking import get_outputs_refs_paths
# When validating an uploaded listener, resolve the checkpoint path from the
# first referenced polyaxon experiment's outputs.
if args.train == 'validlistener' and args.upload:
    refs_paths = get_outputs_refs_paths()['experiments'][0]
    print(refs_paths)
    load_model = os.path.join(refs_paths, args.load)
    print(load_model)

import warnings
warnings.filterwarnings("ignore")

from tensorboardX import SummaryWriter
from polyaxon_client.tracking import get_outputs_path

# In upload (polyaxon) mode, sync vocabularies and image features locally and
# anchor outputs/logs under the experiment's outputs path; otherwise fall back
# to local paths under args.R2R_Aux_path.
if args.upload:
    train_vocab = get_sync_dir(os.path.join(args.upload_path,
                                            args.TRAIN_VOCAB))
    trainval_vocab = get_sync_dir(
        os.path.join(args.upload_path, args.TRAINVAL_VOCAB))
    features = get_sync_dir(
        os.path.join(args.upload_path, args.IMAGENET_FEATURES))
    output_dir = get_outputs_path()
    log_dir = os.path.join(output_dir, "snap", args.name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # Object features are disabled in this variant of the script.
    # sparse_obj_feat = get_sync_dir(os.path.join(args.upload_path, args.SPARSE_OBJ_FEATURES))
    # dense_obj_feat1 = get_sync_dir(os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES1))
    # dense_obj_feat2 = get_sync_dir(os.path.join(args.upload_path, args.DENSE_OBJ_FEATURES2))
    # bbox = get_sync_dir(os.path.join(args.upload_path, args.BBOX_FEATURES))

# (else-branch continues past this chunk)
else:
    train_vocab = os.path.join(args.R2R_Aux_path, args.TRAIN_VOCAB)