Example #1
def valid_speaker(tok, val_envs):
    import tqdm
    listner = Seq2SeqAgent(None, "", tok, args.maxAction)
    speaker = Speaker(None, listner, tok)
    speaker.load(os.path.join(log_dir, 'state_dict', 'best_val_seen_bleu'))
    # speaker.load(os.path.join(log_dir, 'state_dict', 'best_val_unseen_loss'))

    for args.beam in [False, True]:  # note: the loop target rebinds the global args.beam flag
        print("Using Beam Search %s" % args.beam)
        for env_name, (env, evaluator) in val_envs.items():
            if env_name == 'train':
                continue
            print("............ Evaluating %s ............." % env_name)
            speaker.env = env
            path2inst, loss, word_accu, sent_accu = speaker.valid(
                beam=args.beam, wrapper=tqdm.tqdm)
            path_id = next(iter(path2inst.keys()))
            print("Inference: ", tok.decode_sentence(path2inst[path_id]))
            print("GT: ", evaluator.gt[path_id]['instructions'])
            bleu_score, precisions, _ = evaluator.bleu_score(path2inst)
            print(
                "Bleu, Loss, Word_Accu, Sent_Accu for %s are: %0.4f, %0.4f, %0.4f, %0.4f"
                % (env_name, bleu_score, loss, word_accu, sent_accu))
            print(
                "Bleu 1: %0.4f, Bleu 2: %0.4f, Bleu 3: %0.4f, Bleu 4: %0.4f" %
                tuple(precisions))
            print("Average Length %0.4f" % utils.average_length(path2inst))
Example #2
def infer_speaker(env, tok):
    import tqdm
    from utils import load_datasets
    listner = Seq2SeqAgent(env, "", tok, args.maxAction)
    speaker = Speaker(env, listner, tok)
    speaker.load(args.load)

    dataset = load_datasets(env.splits)
    key_map = {}
    for i, item in enumerate(dataset):
        key_map[item["path_id"]] = i
    path2inst = speaker.get_insts(wrapper=tqdm.tqdm)
    for path_id in path2inst.keys():
        speaker_pred = tok.decode_sentence(path2inst[path_id])
        dataset[key_map[path_id]]['instructions'] = [speaker_pred]

    with open("tasks/R2R/data/aug_paths_unseen_infer.json", "w") as f:
        json.dump(dataset, f, indent=4, sort_keys=True)
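For context, a hedged sketch of how infer_speaker could be wired up, reusing the setup pattern from Example #7 below; the R2RBatch arguments and the feature-loading helpers are assumptions borrowed from that example, not verified against this particular repo:

vocab = read_vocab(TRAIN_VOCAB)
tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)
feat_dict = read_img_features(features)
candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)
# Build an environment over the split whose paths we want to caption.
env = R2RBatch(feat_dict, candidate_dict, batch_size=args.batchSize,
               splits=['val_unseen'], tokenizer=tok)
infer_speaker(env, tok)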
Example #3
def valid_speaker(train_env, tok, val_envs):
    import tqdm
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)
    speaker = Speaker(train_env, listner, tok)
    speaker.load(args.load)

    for env_name, (env, evaluator) in val_envs.items():
        if env_name == 'train':
            continue
        print("............ Evaluating %s ............." % env_name)
        speaker.env = env
        path2inst, loss, word_accu, sent_accu = speaker.valid(wrapper=tqdm.tqdm)
        path_id = next(iter(path2inst.keys()))
        print("Inference: ", tok.decode_sentence(path2inst[path_id]))
        print("GT: ", evaluator.gt[str(path_id)]['instructions'])
        bleu_score, precisions = evaluator.bleu_score(path2inst)

        print(len(env.data), len(path2inst.keys()))
        import pdb; pdb.set_trace()  # intentionally drop into the debugger to inspect the results
Example #4
def valid_speaker(tok, val_envs):
    import tqdm
    listner = Seq2SeqAgent(None, "", tok, args.maxAction)
    speaker = Speaker(None, listner, tok)
    speaker.load(args.load)

    for env_name, (env, evaluator) in val_envs.items():
        if env_name == 'train':
            continue
        print("............ Evaluating %s ............." % env_name)
        speaker.env = env
        path2inst, loss, word_accu, sent_accu = speaker.valid(wrapper=tqdm.tqdm)
        path_id = next(iter(path2inst.keys()))
        print("Inference: ", tok.decode_sentence(path2inst[path_id]))
        print("GT: ", evaluator.gt[path_id]['instructions'])
        pathXinst = list(path2inst.items())
        name2score = evaluator.lang_eval(pathXinst, no_metrics={'METEOR'})
        score_string = " "
        for score_name, score in name2score.items():
            score_string += "%s_%s: %0.4f " % (env_name, score_name, score)
        print("For env %s" % env_name)
        print(score_string)
        print("Average Length %0.4f" % utils.average_length(path2inst))
Example #5
def beam_valid(train_env, tok, val_envs={}):
    listener = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    speaker = Speaker(train_env, listener, tok)
    if args.speaker is not None:
        print("Load the speaker from %s." % args.speaker)
        speaker.load(args.speaker)

    print("Loaded the listener model at iter % d" % listener.load(args.load))

    final_log = ""
    for env_name, (env, evaluator) in val_envs.items():
        listener.logs = defaultdict(list)
        listener.env = env

        listener.beam_search_test(speaker)
        results = listener.results

        def cal_score(x, alpha, avg_speaker, avg_listener):
            speaker_score = sum(x["speaker_scores"]) * alpha
            if avg_speaker:
                speaker_score /= len(x["speaker_scores"])
            # normalizer = sum(math.log(k) for k in x['listener_actions'])
            normalizer = 0.
            listener_score = (sum(x["listener_scores"]) + normalizer) * (1 -
                                                                         alpha)
            if avg_listener:
                listener_score /= len(x["listener_scores"])
            return speaker_score + listener_score

        if args.param_search:
            # Search for the best speaker / listener ratio
            interval = 0.01
            logs = []
            for avg_speaker in [False, True]:
                for avg_listener in [False, True]:
                    for alpha in np.arange(0, 1 + interval, interval):
                        result_for_eval = []
                        for key in results:
                            result_for_eval.append({
                                "instr_id": key,
                                "trajectory": max(
                                    results[key]['paths'],
                                    key=lambda x: cal_score(
                                        x, alpha, avg_speaker, avg_listener)
                                )['trajectory']
                            })
                        score_summary, _ = evaluator.score(result_for_eval)
                        for metric, val in score_summary.items():
                            if metric in ['success_rate']:
                                print(
                                    "Avg speaker %s, Avg listener %s, For the speaker weight %0.4f, the result is %0.4f"
                                    % (avg_speaker, avg_listener, alpha, val))
                                logs.append(
                                    (avg_speaker, avg_listener, alpha, val))
            tmp_result = "Env Name %s\n" % (env_name) + \
                    "Avg speaker %s, Avg listener %s, For the speaker weight %0.4f, the result is %0.4f\n" % max(logs, key=lambda x: x[3])
            print(tmp_result)
            # print("Env Name %s" % (env_name))
            # print("Avg speaker %s, Avg listener %s, For the speaker weight %0.4f, the result is %0.4f" %
            #       max(logs, key=lambda x: x[3]))
            final_log += tmp_result
            print()
        else:
            avg_speaker = True
            avg_listener = True
            alpha = args.alpha

            result_for_eval = []
            for key in results:
                result_for_eval.append({
                    "instr_id": key,
                    "trajectory": [(vp, 0, 0) for vp in results[key]['dijk_path']] + \
                                  max(results[key]['paths'],
                                   key=lambda x: cal_score(x, alpha, avg_speaker, avg_listener)
                                  )['trajectory']
                })
            # result_for_eval = utils.add_exploration(result_for_eval)
            score_summary, _ = evaluator.score(result_for_eval)

            if env_name != 'test':
                loss_str = "Env Name: %s" % env_name
                for metric, val in score_summary.items():
                    if metric in ['success_rate']:
                        print(
                            "Avg speaker %s, Avg listener %s, For the speaker weight %0.4f, the result is %0.4f"
                            % (avg_speaker, avg_listener, alpha, val))
                    loss_str += ",%s: %0.4f " % (metric, val)
                print(loss_str)
            print()

            if args.submit:
                json.dump(result_for_eval,
                          open(
                              os.path.join(log_dir,
                                           "submit_%s.json" % env_name), 'w'),
                          sort_keys=True,
                          indent=4,
                          separators=(',', ': '))
    print(final_log)
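cal_score mixes the two model scores log-linearly, score = alpha * speaker + (1 - alpha) * listener, optionally averaging each side by its length. A self-contained toy re-ranking with made-up numbers to show the effect:

# Toy candidates mimicking entries of results[key]['paths']: each path
# carries per-step log-probabilities from the speaker and the listener.
candidates = [
    {"speaker_scores": [-1.2, -0.8], "listener_scores": [-0.2, -0.4], "trajectory": "A"},
    {"speaker_scores": [-0.4, -0.6], "listener_scores": [-1.0, -1.2], "trajectory": "B"},
]

def cal_score(x, alpha, avg_speaker=True, avg_listener=True):
    speaker_score = sum(x["speaker_scores"]) * alpha
    if avg_speaker:
        speaker_score /= len(x["speaker_scores"])
    listener_score = sum(x["listener_scores"]) * (1 - alpha)
    if avg_listener:
        listener_score /= len(x["listener_scores"])
    return speaker_score + listener_score

best = max(candidates, key=lambda x: cal_score(x, alpha=0.5))
print(best["trajectory"])  # 'A': its strong listener score outweighs its weak speaker score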
Example #6
def train(train_env, tok, n_iters, log_every=100, val_envs={}, aug_env=None):
    writer = SummaryWriter(logdir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    speaker = None
    if args.self_train:
        speaker = Speaker(train_env, listner, tok)
        if args.speaker is not None:
            print("Load the speaker from %s." % args.speaker)
            speaker.load(args.speaker)

    start_iter = 0
    if args.load is not None:
        print("LOAD THE listener from %s" % args.load)
        start_iter = listner.load(os.path.join(args.load))

    start = time.time()

    best_val = {
        'val_seen': {
            "accu": 0.,
            "state": "",
            'update': False
        },
        'val_unseen': {
            "accu": 0.,
            "state": "",
            'update': False
        }
    }
    if args.fast_train:
        log_every = 40
    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, start_iter + n_iters - idx)  # don't overrun the last interval when resuming
        iter = idx + interval

        # Train for log_every interval
        if aug_env is None:  # The default training process
            listner.env = train_env
            listner.train(interval,
                          feedback=feedback_method)  # Train interval iters
        else:
            if args.accumulate_grad:
                for _ in range(interval // 2):
                    listner.zero_grad()
                    listner.env = train_env

                    # Train with GT data
                    args.ml_weight = 0.2
                    listner.accumulate_gradient(feedback_method)
                    listner.env = aug_env

                    # Train with Back Translation
                    args.ml_weight = 0.6  # Sem-Configuration
                    listner.accumulate_gradient(feedback_method,
                                                speaker=speaker)
                    listner.optim_step()
            else:
                for _ in range(interval // 2):
                    # Train with GT data
                    listner.env = train_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    # Train with Back Translation
                    listner.env = aug_env
                    args.ml_weight = 0.6
                    listner.train(1, feedback=feedback_method, speaker=speaker)

        # Log the training stats to tensorboard
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(
            listner.logs['critic_loss']) / total  #/ length / args.batchSize
        entropy = sum(
            listner.logs['entropy']) / total  #/ length / args.batchSize
        predict_loss = sum(listner.logs['us_loss']) / max(
            len(listner.logs['us_loss']), 1)
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/unsupervised", predict_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)
        print("total_actions", total)
        print("max_length", length)

        # Run validation
        loss_str = ""
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation loss under the same conditions as training
            iters = None if args.fast_train or env_name != 'train' else 20  # 20 * 64 = 1280

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=iters)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric, val in score_summary.items():
                if metric in ['success_rate']:
                    writer.add_scalar("accuracy/%s" % env_name, val, idx)
                    if env_name in best_val:
                        if val > best_val[env_name]['accu']:
                            best_val[env_name]['accu'] = val
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.3f' % (metric, val)

        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                listner.save(
                    idx,
                    os.path.join("snap", args.name, "state_dict",
                                 "best_%s" % (env_name)))

        print(('%s (%d %d%%) %s' % (timeSince(start,
                                              float(iter) / n_iters), iter,
                                    float(iter) / n_iters * 100, loss_str)))

        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])

        if iter % 50000 == 0:
            listner.save(
                idx,
                os.path.join("snap", args.name, "state_dict",
                             "Iter_%06d" % (iter)))

    listner.save(
        idx,
        os.path.join("snap", args.name, "state_dict", "LAST_iter%d" % (idx)))
Example #7
def create_augment_data():
    setup()

    # Create a batch training environment that will also preprocess text
    vocab = read_vocab(TRAIN_VOCAB)
    tok = Tokenizer(vocab=vocab, encoding_length=args.maxInput)

    # Load features
    feat_dict = read_img_features(features)
    candidate_dict = utils.read_candidates(CANDIDATE_FEATURES)

    # The datasets to be augmented
    print("Start to augment the data")
    aug_envs = []
    # aug_envs.append(
    #     R2RBatch(
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['train'], tokenizer=tok
    #     )
    # )
    # aug_envs.append(
    #     SemiBatch(False, 'tasks/R2R/data/all_paths_46_removetrain.json',
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['train', 'val_seen'], tokenizer=tok)
    # )
    aug_envs.append(
        SemiBatch(False,
                  'tasks/R2R/data/all_paths_46_removevalunseen.json',
                  "unseen",
                  feat_dict,
                  candidate_dict,
                  batch_size=args.batchSize,
                  splits=['val_unseen'],
                  tokenizer=tok))
    aug_envs.append(
        SemiBatch(False,
                  'tasks/R2R/data/all_paths_46_removetest.json',
                  "test",
                  feat_dict,
                  candidate_dict,
                  batch_size=args.batchSize,
                  splits=['test'],
                  tokenizer=tok))
    # aug_envs.append(
    #     R2RBatch(
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['val_seen'], tokenizer=tok
    #     )
    # )
    # aug_envs.append(
    #     R2RBatch(
    #         feat_dict, candidate_dict, batch_size=args.batchSize, splits=['val_unseen'], tokenizer=tok
    #     )
    # )

    for snapshot in os.listdir(os.path.join(log_dir, 'state_dict')):
        # Select a particular snapshot to process (otherwise it runs for every snapshot).
        if snapshot != "best_val_unseen_bleu":
            continue

        # Create Speaker
        listner = Seq2SeqAgent(aug_envs[0], "", tok, args.maxAction)
        speaker = Speaker(aug_envs[0], listner, tok)

        # Load Weight
        load_iter = speaker.load(os.path.join(log_dir, 'state_dict', snapshot))
        print("Load from iter %d" % (load_iter))

        # Augment the env from aug_envs
        for aug_env in aug_envs:
            speaker.env = aug_env

            # Create the aug data
            import tqdm
            path2inst = speaker.get_insts(beam=args.beam, wrapper=tqdm.tqdm)
            data = []
            for datum in aug_env.fake_data:
                datum = datum.copy()
                path_id = datum['path_id']
                if path_id in path2inst:
                    datum['instructions'] = [
                        tok.decode_sentence(path2inst[path_id])
                    ]
                    datum.pop('instr_encoding')  # Remove Redundant keys
                    datum.pop('instr_id')
                    data.append(datum)

            print("Totally, %d data has been generated for snapshot %s." %
                  (len(data), snapshot))
            print("Average Length %0.4f" % utils.average_length(path2inst))
            print(datum)  # Print a Sample

            # Save the data
            import json
            os.makedirs(os.path.join(log_dir, 'aug_data'), exist_ok=True)
            beam_tag = "_beam" if args.beam else ""
            json.dump(data,
                      open(
                          os.path.join(
                              log_dir, 'aug_data', '%s_%s%s.json' %
                              (snapshot, aug_env.name, beam_tag)), 'w'),
                      sort_keys=True,
                      indent=4,
                      separators=(',', ': '))
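Downstream, the dumped file can be read back as a normal R2R-style dataset; a hedged sketch (the file name is illustrative, composed from the snapshot and env names used above):

import json
import os

aug_path = os.path.join(log_dir, 'aug_data', 'best_val_unseen_bleu_unseen.json')
with open(aug_path) as f:
    aug_data = json.load(f)
print("Loaded %d augmented items" % len(aug_data))
print(aug_data[0]['instructions'][0])  # one generated instruction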
Example #8
def train(train_env, tok, n_iters, log_every=100, val_envs={}, aug_env=None):
    writer = SummaryWriter(logdir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)

    speaker = None
    if args.self_train:
        speaker = Speaker(train_env, listner, tok)
        if args.speaker is not None:
            if args.upload:
                print("Load the speaker from %s." % args.speaker)
                speaker.load(
                    get_sync_dir(os.path.join(args.upload_path, args.speaker)))
            else:
                print("Load the speaker from %s." % args.speaker)
                speaker.load(os.path.join(args.R2R_Aux_path, args.speaker))

    start_iter = 0
    if args.load is not None:
        if args.upload:
            refs_paths = get_outputs_refs_paths()['experiments'][0]
            print(refs_paths)
            load_model = os.path.join(refs_paths, args.load)
            print(load_model)
            print("LOAD THE listener from %s" % load_model)
            start_iter = listner.load(load_model)
        else:
            print("LOAD THE listener from %s" % args.load)
            start_iter = listner.load(
                os.path.join(args.R2R_Aux_path, args.load))

    start = time.time()

    best_val = {
        'val_seen': {
            "accu": 0.,
            "state": "",
            'update': False
        },
        'val_unseen': {
            "accu": 0.,
            "state": "",
            'update': False
        }
    }
    if args.fast_train:
        log_every = 40
    for idx in range(start_iter, start_iter + n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, start_iter + n_iters - idx)
        iter = idx + interval

        # Train for log_every interval
        if aug_env is None:  # The default training process
            listner.env = train_env
            listner.train(interval,
                          feedback=feedback_method)  # Train interval iters
        else:
            if args.accumulate_grad:
                for _ in range(interval // 2):
                    listner.zero_grad()
                    listner.env = train_env

                    # Train with GT data
                    args.ml_weight = 0.2
                    listner.accumulate_gradient(feedback_method)
                    listner.env = aug_env

                    # Train with Back Translation
                    args.ml_weight = 0.6  # Sem-Configuration
                    listner.accumulate_gradient(feedback_method,
                                                speaker=speaker)
                    listner.optim_step()
            else:
                for _ in range(interval // 2):
                    # Train with GT data
                    listner.env = train_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    # Train with Back Translation
                    listner.env = aug_env
                    args.ml_weight = 0.6
                    listner.train(1, feedback=feedback_method, speaker=speaker)

        # Log the training stats to tensorboard
        total = max(sum(listner.logs['total']), 1)
        # import pdb; pdb.set_trace() # length_rl == length_ml ? entropy length
        assert (max(len(listner.logs['rl_loss']),
                    1) == max(len(listner.logs['ml_loss']), 1))
        max_rl_length = max(len(listner.logs['critic_loss']), 1)
        log_length = max(len(listner.logs['rl_loss']), 1)
        rl_loss = sum(listner.logs['rl_loss']) / log_length
        ml_loss = sum(listner.logs['ml_loss']) / log_length
        critic_loss = sum(listner.logs['critic_loss']
                          ) / log_length  #/ length / args.batchSize
        spe_loss = sum(listner.logs['spe_loss']) / log_length
        pro_loss = sum(listner.logs['pro_loss']) / log_length
        mat_loss = sum(listner.logs['mat_loss']) / log_length
        fea_loss = sum(listner.logs['fea_loss']) / log_length
        ang_loss = sum(listner.logs['ang_loss']) / log_length
        entropy = sum(
            listner.logs['entropy']) / log_length  #/ length / args.batchSize
        predict_loss = sum(listner.logs['us_loss']) / log_length
        writer.add_scalar("loss/rl_loss", rl_loss, idx)
        writer.add_scalar("loss/ml_loss", ml_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/spe_loss", spe_loss, idx)
        writer.add_scalar("loss/pro_loss", pro_loss, idx)
        writer.add_scalar("loss/mat_loss", mat_loss, idx)
        writer.add_scalar("loss/fea_loss", fea_loss, idx)
        writer.add_scalar("loss/ang_loss", ang_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_rl_length", max_rl_length, idx)
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("loss/unsupervised", predict_loss, idx)
        print("total_actions", total)
        print("max_rl_length", max_rl_length)

        # Run validation
        loss_str = ""
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation loss under the same conditions as training
            iters = None if args.fast_train or env_name != 'train' else 20  # 20 * 64 = 1280

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=iters)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += "%s " % env_name
            for metric, val in score_summary.items():
                if metric in ['success_rate']:
                    loss_str += ', %s: %.4f' % (metric, val)
                    writer.add_scalar("%s/accuracy" % env_name, val, idx)
                    if env_name in best_val:
                        if val > best_val[env_name]['accu']:
                            best_val[env_name]['accu'] = val
                            best_val[env_name]['update'] = True
                if metric in ['spl']:
                    writer.add_scalar("%s/spl" % env_name, val, idx)
                    loss_str += ', %s: %.4f' % (metric, val)
            loss_str += '\n'
        loss_str += '\n'

        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d \n%s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                file_dir = os.path.join(output_dir, "snap", args.name,
                                        "state_dict", "best_%s" % (env_name))
                listner.save(idx, file_dir)
        print(('%s (%d %d%%) \n%s' % (timeSince(start,
                                                float(iter) / n_iters), iter,
                                      float(iter) / n_iters * 100, loss_str)))

        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])

        if iter % args.save_iter == 0:
            file_dir = os.path.join(output_dir, "snap", args.name,
                                    "state_dict", "Iter_%06d" % (iter))
            listner.save(idx, file_dir)
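timeSince in the progress prints of these training loops is typically the classic PyTorch-tutorial ETA helper; a sketch under that assumption:

import math
import time

def as_minutes(s):
    m = math.floor(s / 60)
    return '%dm %ds' % (m, s - m * 60)

def timeSince(since, percent):
    now = time.time()
    s = now - since       # elapsed seconds
    es = s / percent      # estimated total
    rs = es - s           # estimated remaining
    return '%s (- %s)' % (as_minutes(s), as_minutes(rs))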
Example #9
def train_speaker(train_env, tok, n_iters, log_every=500, val_envs={}):
    writer = SummaryWriter(logdir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction)
    speaker = Speaker(train_env, listner, tok)

    if args.load is not None:
        print("LOAD THE Speaker from %s" % args.load)
        speaker.load(os.path.join(args.load))

    if args.fast_train:
        log_every = 40

    best_bleu = defaultdict(lambda: 0)
    best_loss = defaultdict(lambda: 1232)  # arbitrarily large initial loss
    for idx in range(0, n_iters, log_every):
        interval = min(log_every, n_iters - idx)

        # Train for log_every interval
        speaker.env = train_env
        speaker.train(interval)  # Train interval iters

        print()
        print("Iter: %d" % idx)

        # Evaluation
        for env_name, (env, evaluator) in val_envs.items():
            if 'train' in env_name:  # Ignore the large training set for the efficiency
                continue

            print("............ Evaluating %s ............." % env_name)
            speaker.env = env
            path2inst, loss, word_accu, sent_accu = speaker.valid()
            path_id = next(iter(path2inst.keys()))
            print('path_id:', path_id)
            print("Inference: ", tok.decode_sentence(path2inst[path_id]))
            print("GT: ", evaluator.gt[str(path_id)]['instructions'])
            bleu_score, precisions = evaluator.bleu_score(path2inst)

            # Tensorboard log
            writer.add_scalar("bleu/%s" % (env_name), bleu_score, idx)
            writer.add_scalar("loss/%s" % (env_name), loss, idx)
            writer.add_scalar("word_accu/%s" % (env_name), word_accu, idx)
            writer.add_scalar("sent_accu/%s" % (env_name), sent_accu, idx)
            writer.add_scalar("bleu4/%s" % (env_name), precisions[3], idx)

            # Save the model according to the bleu score
            if bleu_score > best_bleu[env_name]:
                best_bleu[env_name] = bleu_score
                print('Save the model with %s BEST env bleu %0.4f' %
                      (env_name, bleu_score))
                speaker.save(
                    idx,
                    os.path.join(log_dir, 'state_dict',
                                 'best_%s_bleu' % env_name))

            if loss < best_loss[env_name]:
                best_loss[env_name] = loss
                print('Save the model with %s BEST env loss %0.4f' %
                      (env_name, loss))
                speaker.save(
                    idx,
                    os.path.join(log_dir, 'state_dict',
                                 'best_%s_loss' % env_name))

            # Screen print out
            print(
                "Bleu 1: %0.4f, Bleu 2: %0.4f, Bleu 3: %0.4f, Bleu 4: %0.4f" %
                tuple(precisions))
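The evaluator.bleu_score(path2inst) calls above return a corpus BLEU plus the four n-gram precisions. A hedged, standalone approximation built on NLTK's corpus_bleu (the repo's own evaluator is a method and may differ in tokenization, smoothing, and brevity penalty):

from nltk.translate.bleu_score import corpus_bleu

def bleu_score(evaluator, tok, path2inst):
    refs, hyps = [], []
    for path_id, inst in path2inst.items():
        gt_insts = evaluator.gt[str(path_id)]['instructions']  # GT strings
        refs.append([sent.split() for sent in gt_insts])
        hyps.append(tok.decode_sentence(inst).split())
    bleu = corpus_bleu(refs, hyps)
    precisions = [
        corpus_bleu(refs, hyps, weights=tuple(float(j == i) for j in range(4)))
        for i in range(4)
    ]
    return bleu, precisions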
Example #10
def train(train_env, tok, n_iters, log_every=100, val_envs={}, aug_env=None,stok=None,press_env=None):
    writer = SummaryWriter(logdir=log_dir)
    listner = Seq2SeqAgent(train_env, "", tok, args.maxAction,encoder_type=args.encoderType)

    speaker = None
    if args.self_train:
        if args.encoderType in ['DicEncoder','CEncoder','Dic']:
            print("Note that u r using bert-based encoder,use speaker tokenizer to decode instructions!")
            speaker = Speaker(train_env, listner, stok)
        else:
            speaker = Speaker(train_env, listner, tok)

        if args.speaker is not None:
            print("Load the speaker from %s." % args.speaker)
            speaker.load(args.speaker)

    start_iter = 0
    if args.load is not None:
        print("LOAD THE listener from %s" % args.load)
        start_iter = listner.load(os.path.join(args.load))

    listner.encoder.bert.update_lang_bert,listner.encoder.bert.config.update_lang_bert = args.d_transformer_update, args.d_transformer_update
    listner.encoder.bert.update_add_layer, listner.encoder.bert.config.update_add_layer = args.d_update_add_layer, args.d_update_add_layer

    myidx = 0
    best_spl = 0
    best_sr_sum = 0
    start = time.time()

    best_val = {'val_seen': {"accu": 0., "state":"", 'update':False},
                'val_unseen': {"accu": 0., "state":"", 'update':False}}
    best_val_sr_sum = {'sr_sum': {"accu": 0., "state":"", 'update':False},
                'spl_unseen': {"accu": 0., "state":"", 'update':False}}
    if args.fast_train:
        log_every = 40
    for idx in range(start_iter, start_iter+n_iters, log_every):
        listner.logs = defaultdict(list)
        interval = min(log_every, start_iter + n_iters - idx)  # don't overrun the last interval when resuming
        iter = idx + interval

        myidx += interval
        print()
        print("PROGRESS: {}%".format(round((myidx) * 100 / n_iters, 4)))
        print()

        # Train for log_every interval
        if aug_env is None:     # The default training process
            listner.env = train_env
            listner.train(interval, feedback=feedback_method)   # Train interval iters
        else:
            if args.accumulate_grad:
                for _ in range(interval // 2):
                    listner.zero_grad()
                    listner.env = train_env

                    # Train with GT data
                    args.ml_weight = 0.2
                    listner.accumulate_gradient(feedback_method)
                    listner.env = aug_env

                    # Train with Back Translation
                    args.ml_weight = 0.6        # Sem-Configuration
                    listner.accumulate_gradient(feedback_method, speaker=speaker)
                    listner.optim_step()
            else:
                for _ in range(interval // 2):
                    # Train with GT data
                    listner.env = train_env
                    args.ml_weight = 0.2
                    listner.train(1, feedback=feedback_method)

                    # Train with Back Translation
                    listner.env = aug_env
                    args.ml_weight = 0.6
                    listner.train(1, feedback=feedback_method, speaker=speaker)

        # Log the training stats to tensorboard
        total = max(sum(listner.logs['total']), 1)
        length = max(len(listner.logs['critic_loss']), 1)
        critic_loss = sum(listner.logs['critic_loss']) / total #/ length / args.batchSize
        entropy = sum(listner.logs['entropy']) / total #/ length / args.batchSize
        predict_loss = sum(listner.logs['us_loss']) / max(len(listner.logs['us_loss']), 1)
        writer.add_scalar("loss/critic", critic_loss, idx)
        writer.add_scalar("policy_entropy", entropy, idx)
        writer.add_scalar("loss/unsupervised", predict_loss, idx)
        writer.add_scalar("total_actions", total, idx)
        writer.add_scalar("max_length", length, idx)
        print("total_actions", total)
        print("max_length", length)
        data_log['iteration'].append(iter)

        # Run validation
        loss_str = ""
        current_sr_sum = 0
        for env_name, (env, evaluator) in val_envs.items():
            listner.env = env

            # Get validation loss under the same conditions as training
            iters = None if args.fast_train or env_name != 'train' else 20     # 20 * 64 = 1280

            # Get validation distance from goal under test evaluation conditions
            listner.test(use_dropout=False, feedback='argmax', iters=iters)
            result = listner.get_results()
            score_summary, _ = evaluator.score(result)
            loss_str += ", %s " % env_name
            for metric,val in score_summary.items():
                if metric in ['success_rate']:
                    writer.add_scalar("accuracy/%s" % env_name, val, idx)
                    if env_name in best_val:
                        current_sr_sum += val
                        if val > best_val[env_name]['accu']:
                            best_val[env_name]['accu'] = val
                            best_val[env_name]['update'] = True
                loss_str += ', %s: %.3f' % (metric, val)

                if metric == 'spl' and env_name == 'val_unseen':
                    if val > best_spl:
                        best_spl = val
                        best_val_sr_sum['spl_unseen']['accu'] = best_spl
                        best_val_sr_sum['spl_unseen']['update'] = True

                data_log['%s %s' % (env_name, metric)].append(val)
        if current_sr_sum > best_sr_sum:
            best_sr_sum = current_sr_sum
            best_val_sr_sum['sr_sum']['accu'] = best_sr_sum
            best_val_sr_sum['sr_sum']['update'] = True


        print()
        print("EVALERR: {}%".format(best_spl))
        print()

        for env_name in best_val:
            if best_val[env_name]['update']:
                best_val[env_name]['state'] = 'Iter %d %s' % (iter, loss_str)
                best_val[env_name]['update'] = False
                if args.philly:
                    listner.save(idx, os.path.join(log_dir, "state_dict", "best_%s" % (env_name)))
                else:
                    listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (env_name)))

        for metric in best_val_sr_sum:
            if best_val_sr_sum[metric]['update']:
                best_val_sr_sum[metric]['state'] = 'Iter %d %s' % (iter, loss_str)
                best_val_sr_sum[metric]['update'] = False
                if args.philly:
                    listner.save(idx, os.path.join(log_dir, "state_dict", "best_%s" % (metric)))
                else:
                    listner.save(idx, os.path.join("snap", args.name, "state_dict", "best_%s" % (metric)))


        print(('%s (%d %d%%) %s' % (timeSince(start, float(iter)/n_iters),
                                             iter, float(iter)/n_iters*100, loss_str)))
        if iter % 1000 == 0:
            print("BEST RESULT TILL NOW")
            for env_name in best_val:
                print(env_name, best_val[env_name]['state'])


        df = pd.DataFrame(data_log)
        df = df.set_index('iteration')  # set_index is not in-place; keep the returned frame
        df_path = '%s/plot_log.csv' % (plot_dir)
        write_num = 0
        while write_num < 20:
            try:
                df.to_csv(df_path)
                break
            except Exception:  # e.g. the file is temporarily locked; retry
                write_num += 1

        #if iter % 50000 == 0:
        #    if args.philly:
        #        listner.save(idx, os.path.join(log_dir, "state_dict", "Iter_%06d" % (iter)))
        #    else:
        #        listner.save(idx, os.path.join(log_dir, "state_dict", "Iter_%06d" % (iter)))
        #    #listner.save(idx, os.path.join(log_dir, "state_dict", "Iter_%06d" % (iter)))

    if args.philly:
        listner.save(idx, os.path.join(log_dir, "state_dict", "LAST_iter%d" % (idx)))
    else:
        listner.save(idx, os.path.join('snap', args.name, "state_dict", "LAST_iter%d" % (idx)))
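data_log and plot_dir are module-level state in the original file; a minimal sketch of the assumed setup behind the CSV logging above (the plot directory location is a guess):

import os
from collections import defaultdict

plot_dir = os.path.join(log_dir, 'plots')
os.makedirs(plot_dir, exist_ok=True)
data_log = defaultdict(list)  # metric name -> list of values, one per log step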
Example #11
    #     torch_ds = torch.utils.data.Subset(torch_ds, range(1000))
    print("The size of data split %s is %d" % (split, len(torch_ds)))
    loader = torch.utils.data.DataLoader(
        torch_ds,
        batch_size=args.batch_size, shuffle=shuffle,
        num_workers=args.workers, pin_memory=True,
        drop_last=drop_last)
    return dataset, torch_ds, loader

if 'speaker' in args.train:
    train_tuple = get_tuple(args.dataset, 'train', shuffle=False, drop_last=True)
    valid_tuple = get_tuple(args.dataset, 'valid', shuffle=False, drop_last=False)
    speaker = Speaker(train_tuple[0])   # [0] is the dataset
    if args.load is not None:
        print("Load speaker from %s." % args.load)
        speaker.load(args.load)
        scores, result = speaker.evaluate(valid_tuple)
        print("Have result for %d data" % len(result))
        print("The validation result is:")
        print(scores)
    if args.train == 'speaker':
        speaker.train(train_tuple, valid_tuple, args.epochs)
    elif args.train == 'rlspeaker':
        speaker.train(train_tuple, valid_tuple, args.epochs, rl=True)
    elif args.train == 'validspeaker':
        scores, result = speaker.evaluate(valid_tuple)
        print(scores)
    elif args.train == 'testspeaker':
        test_tuple = get_tuple(args.dataset, 'test', shuffle=False, drop_last=False)
        scores, result = speaker.evaluate(test_tuple)
        print("Test:")