コード例 #1
0
def train_speaker(train_env, tok, n_iters, log_every=500, val_envs=None):
    """Train a Speaker in chunks, evaluating and checkpointing after each chunk.

    Training runs for ``n_iters`` iterations, split into chunks of
    ``log_every`` iterations (forced to 40 when ``args.fast_train`` is set).

    After every chunk:
    * Each validation environment whose name does not contain ``'train'`` is
      evaluated.
    * Its BLEU, loss, word/sentence accuracy and BLEU-4 precision are written
      to TensorBoard under ``log_dir``.
    * The model is saved to ``<log_dir>/state_dict/best_<env>_bleu`` when
      that environment's BLEU improves.
    * The model is saved to ``<log_dir>/state_dict/best_<env>_loss`` when
      that environment's loss decreases.

    Args:
        train_env: environment assigned to ``speaker.env`` before each
            training chunk; also handed to the wrapped Seq2SeqAgent.
        tok: object providing ``decode_sentence``, used to print a sample
            generated instruction.
        n_iters: total number of training iterations.
        log_every: number of training iterations between evaluations.
        val_envs: mapping ``env_name -> (env, evaluator)``. Defaults to no
            validation environments.
    """
    # Avoid a shared mutable default argument.
    if val_envs is None:
        val_envs = {}

    writer = SummaryWriter(logdir=log_dir)
    try:
        listener = Seq2SeqAgent(train_env, "", tok, args.maxAction)
        speaker = Speaker(train_env, listener, tok)

        if args.fast_train:
            log_every = 40

        # Per-environment best scores seen so far. The loss starts at +inf so
        # the first evaluation is always recorded, whatever its magnitude.
        best_bleu = defaultdict(lambda: 0)
        best_loss = defaultdict(lambda: float('inf'))
        for idx in range(0, n_iters, log_every):
            # The last chunk may be shorter than log_every.
            interval = min(log_every, n_iters - idx)

            # Train for this interval.
            speaker.env = train_env
            speaker.train(interval)

            print()
            print("Iter: %d" % idx)

            # Evaluation
            for env_name, (env, evaluator) in val_envs.items():
                if 'train' in env_name:  # Skip the large training split for efficiency
                    continue

                print("............ Evaluating %s ............." % env_name)
                speaker.env = env
                path2inst, loss, word_accu, sent_accu = speaker.valid()

                # Print one sample prediction next to its ground truth. The guard
                # avoids a StopIteration when nothing was generated.
                if path2inst:
                    path_id = next(iter(path2inst))
                    print("Inference: ", tok.decode_sentence(path2inst[path_id]))
                    print("GT: ", evaluator.gt[str(path_id)]['instructions'])
                bleu_score, precisions = evaluator.bleu_score(path2inst)

                # Tensorboard log
                writer.add_scalar("bleu/%s" % (env_name), bleu_score, idx)
                writer.add_scalar("loss/%s" % (env_name), loss, idx)
                writer.add_scalar("word_accu/%s" % (env_name), word_accu, idx)
                writer.add_scalar("sent_accu/%s" % (env_name), sent_accu, idx)
                writer.add_scalar("bleu4/%s" % (env_name), precisions[3], idx)

                # Checkpoint whenever this environment's BLEU improves.
                if bleu_score > best_bleu[env_name]:
                    best_bleu[env_name] = bleu_score
                    print('Save the model with %s BEST env bleu %0.4f' % (env_name, bleu_score))
                    speaker.save(idx, os.path.join(log_dir, 'state_dict', 'best_%s_bleu' % env_name))

                # Checkpoint whenever this environment's loss decreases.
                if loss < best_loss[env_name]:
                    best_loss[env_name] = loss
                    print('Save the model with %s BEST env loss %0.4f' % (env_name, loss))
                    speaker.save(idx, os.path.join(log_dir, 'state_dict', 'best_%s_loss' % env_name))

                # Screen print out
                print("Bleu 1: %0.4f Bleu 2: %0.4f, Bleu 3 :%0.4f,  Bleu 4: %0.4f" % tuple(precisions))
    finally:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
コード例 #2
0
ファイル: FinalGUI.py プロジェクト: thewolfe1/TheThirdEye
    def next(self):
        """Train the Speech model, then the Speaker model, then request a window switch.

        Every step is called in sequence. ``self.switch_window`` is emitted
        only after all of the training calls have returned.
        """
        speech_model = Speech()
        for prepare_step in (
            speech_model.expend_data,
            speech_model.read_data,
            speech_model.preprocess_labels,
            speech_model.preprocess_data,
        ):
            prepare_step()
        # Positional training arguments; their meaning is defined by
        # Speech.train, which is not visible here.
        speech_model.train(10, 4, 1, 1000)

        speaker_model = Speaker()
        for prepare_step in (speaker_model.audio_to_image, speaker_model.preprocess):
            prepare_step()
        speaker_model.train()

        # Notify whatever is connected to this signal (presumably the GUI
        # window controller; confirm) that training has finished.
        self.switch_window.emit()
コード例 #3
0
ファイル: FinalGUI.py プロジェクト: thewolfe1/TheThirdEye
    def TrainModel(self):
        """Run the full training pipeline for the Speech and Speaker models.

        A fresh ``Speech`` model is prepared and trained first. A fresh
        ``Speaker`` model is then prepared and trained. Where the training data
        comes from is decided inside those classes (presumably previously
        recorded audio; confirm). Nothing is returned.
        """
        speech_model = Speech()
        for prepare_step in (
            speech_model.expend_data,
            speech_model.read_data,
            speech_model.preprocess_labels,
            speech_model.preprocess_data,
        ):
            prepare_step()
        # Positional training arguments; their meaning is defined by
        # Speech.train, which is not visible here.
        speech_model.train(10, 4, 1, 1000)

        speaker_model = Speaker()
        for prepare_step in (speaker_model.audio_to_image, speaker_model.preprocess):
            prepare_step()
        speaker_model.train()
コード例 #4
0
ファイル: main.py プロジェクト: pcuenca/VisualRelationships
        drop_last=drop_last)
    return dataset, torch_ds, loader

if 'speaker' in args.train:
    # Speaker modes: 'speaker' (train), 'rlspeaker' (train with rl=True),
    # 'validspeaker' (evaluate on valid), 'testspeaker' (evaluate on test and
    # dump the results to test_result.json).
    train_tuple = get_tuple(args.dataset, 'train', shuffle=False, drop_last=True)
    valid_tuple = get_tuple(args.dataset, 'valid', shuffle=False, drop_last=False)
    speaker = Speaker(train_tuple[0])   # [0] is the dataset
    if args.load is not None:
        # Report validation scores of the loaded checkpoint before any action.
        print("Load speaker from %s." % args.load)
        speaker.load(args.load)
        scores, result = speaker.evaluate(valid_tuple)
        print("Have result for %d data" % len(result))
        print("The validation result is:")
        print(scores)
    if args.train == 'speaker':
        speaker.train(train_tuple, valid_tuple, args.epochs)
    elif args.train == 'rlspeaker':
        speaker.train(train_tuple, valid_tuple, args.epochs, rl=True)
    elif args.train == 'validspeaker':
        scores, result = speaker.evaluate(valid_tuple)
        print(scores)
    elif args.train == 'testspeaker':
        test_tuple = get_tuple(args.dataset, 'test', shuffle=False, drop_last=False)
        scores, result = speaker.evaluate(test_tuple)
        print("Test:")
        print("Have result for %d data" % len(result))
        print(scores)
        import json
        # Close the file deterministically so the JSON is fully flushed to disk.
        with open("test_result.json", 'w') as result_file:
            json.dump(result, result_file)
elif 'nlvr' in args.train:
    train_tuple = get_tuple(args.dataset, 'train', task='nlvr', shuffle=False, drop_last=True)