Code example #1
    def show_samples(self,
                     dataset,
                     num_samples=10,
                     cuda=True,
                     logger=None,
                     decoding_strategy='sample',
                     indices=None,
                     beam_width=4):
        """Decode a few samples from `dataset` and print (or log) the observed
        landmarks, the action sequence, the ground-truth utterance, and the
        model-generated utterance for each sample."""
        if indices is None:
            # Pick random sample indices when none are given explicitly.
            indices = [random.randint(0, len(dataset) - 1)
                       for _ in range(num_samples)]

        collate_fn = get_collate_fn(cuda)

        data = [dataset[ind] for ind in indices]
        batch = collate_fn(data)

        out = self.forward(batch[0],
                           decoding_strategy=decoding_strategy,
                           train=False,
                           beam_width=beam_width)

        generated_utterance = out['utterance'].cpu().data
        # Fall back to print() when no logger is provided.
        logger_fn = logger if logger else print

        for i in range(len(indices)):
            # Decode the observed landmarks for this sample into readable text.
            o = ''
            for obs in data[i][0]['goldstandard']:
                o += '(' + ','.join(
                    [dataset.map.landmark_dict.decode(o_ind)
                     for o_ind in obs]) + ') ,'
            # Decode the action sequence into readable text.
            a = ','.join([
                dataset.act_dict.decode(a_ind)
                for a_ind in data[i][0]['actions']
            ])

            logger_fn('Observations: ' + o)
            logger_fn('Actions: ' + a)
            logger_fn('GT: ' +
                      dataset.dict.decode(batch[0]['utterance'][i, 1:]))
            logger_fn('Sample: ' +
                      dataset.dict.decode(generated_utterance[i, :]))
            logger_fn('-' * 80)
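
For context, a minimal usage sketch (hypothetical, not taken from the repository); it assumes `tourist` is a trained instance of the class that defines `show_samples` above, and that `valid_data` and `logger` follow the interfaces used in that method:

# Hypothetical usage sketch: `tourist`, `valid_data`, and `logger` are assumed
# to exist and to follow the interface used in show_samples() above.
tourist.show_samples(valid_data,
                     num_samples=5,
                     cuda=False,
                     logger=logger.info,          # route output through a logger
                     decoding_strategy='sample')  # default strategy shown above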
Code example #2
    parser.add_argument('--report-every', type=int, default=5)
    parser.add_argument('--num-epochs', type=int, default=500, help='Number of epochs')

    args = parser.parse_args()

    exp_dir = os.path.join(args.exp_dir, args.exp_name)
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    logger = create_logger(os.path.join(exp_dir, 'log.txt'))
    logger.info(args)

    train_data = TalkTheWalkEmergent(args.data_dir, 'train', goldstandard_features=True, T=args.T)
    train_loader = DataLoader(train_data, args.batch_sz, collate_fn=get_collate_fn(args.cuda), shuffle=True)

    valid_data = TalkTheWalkEmergent(args.data_dir, 'valid', goldstandard_features=True, T=args.T)
    valid_loader = DataLoader(valid_data, args.batch_sz, collate_fn=get_collate_fn(args.cuda))

    test_data = TalkTheWalkEmergent(args.data_dir, 'test', goldstandard_features=True, T=args.T)
    test_loader = DataLoader(test_data, args.batch_sz, collate_fn=get_collate_fn(args.cuda))

    guide = GuideContinuous(args.vocab_sz, len(train_data.map.landmark_dict),
                            apply_masc=args.apply_masc, T=args.T)
    tourist = TouristContinuous(args.vocab_sz, len(train_data.map.landmark_dict), len(train_data.act_dict),
                                apply_masc=args.apply_masc, T=args.T)

    # Jointly optimize the tourist and guide parameters with a single Adam optimizer.
    params = list(tourist.parameters()) + list(guide.parameters())
    opt = optim.Adam(params)
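
The excerpt above begins after the argument parser has been built. A minimal sketch of the flags it relies on (names are inferred from the `args.*` attributes used above; the defaults are illustrative, not the original script's values):

# Sketch only: flag names inferred from the attributes used in the excerpt,
# default values are placeholders.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--exp-dir', type=str, default='exp')      # -> args.exp_dir
parser.add_argument('--exp-name', type=str, default='my-run')  # -> args.exp_name
parser.add_argument('--data-dir', type=str, default='data')    # -> args.data_dir
parser.add_argument('--batch-sz', type=int, default=64)        # -> args.batch_sz
parser.add_argument('--vocab-sz', type=int, default=500)       # -> args.vocab_sz
parser.add_argument('--T', type=int, default=2)                # -> args.T
parser.add_argument('--apply-masc', action='store_true')       # -> args.apply_masc
parser.add_argument('--cuda', action='store_true')             # -> args.cuda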
Code example #3
    logger = create_logger(os.path.join(exp_dir, 'log.txt'))
    logger.info(args)

    data = TalkTheWalkLandmarks(args.data_dir, args.resnet_features,
                                args.fasttext_features,
                                args.textrecog_features)

    train_data, valid_data = create_split(data)
    add_weights(train_data, valid_data)
    train_data = DatasetHolder(train_data)
    valid_data = DatasetHolder(valid_data)

    train_loader = DataLoader(train_data,
                              args.batch_sz,
                              collate_fn=get_collate_fn(args.cuda),
                              shuffle=True)
    valid_loader = DataLoader(valid_data,
                              args.batch_sz,
                              collate_fn=get_collate_fn(args.cuda))

    # Trivial baselines on the validation targets: predict an all-positive label
    # or a uniformly random binary label for every example.
    target = numpy.array(
        [valid_data[i]['target'] for i in range(len(valid_data))])
    ones = numpy.ones_like(target)
    rand = numpy.random.randint(2, size=target.shape)

    logger.info('Baselines' + '-' * 70)
    logger.info("All positive: {}, {}, {}".format(
        f1_score(target, ones, average='weighted'),
        precision_score(target, ones, average='weighted'),
        recall_score(target, ones, average='weighted')))
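
The `rand` array is computed but the excerpt is cut off before it is used; presumably the script reports a random baseline in the same format as the all-positive one. A sketch of that continuation (an assumption, not the original code):

    # Assumed continuation (not the original code): report the random baseline
    # in the same format as the all-positive baseline above.
    logger.info("Random: {}, {}, {}".format(
        f1_score(target, rand, average='weighted'),
        precision_score(target, rand, average='weighted'),
        recall_score(target, rand, average='weighted')))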
Code example #4
    data_dir = args.data_dir

    if args.trajectories == 'all':
        dictionary = Dictionary(file=os.path.join(data_dir, 'dict.txt'), min_freq=3)
        train_data = TalkTheWalkEmergent(data_dir, 'train', T=args.T)
        train_data.dict = dictionary
        valid_data = TalkTheWalkEmergent(data_dir, 'valid', T=args.T)
        valid_data.dict = dictionary
        test_data = TalkTheWalkEmergent(data_dir, 'test', T=args.T)
        test_data.dict = dictionary
    elif args.trajectories == 'human':
        train_data = TalkTheWalkLanguage(data_dir, 'train')
        valid_data = TalkTheWalkLanguage(data_dir, 'valid')
        test_data = TalkTheWalkLanguage(data_dir, 'test')
    else:
        # Guard against a NameError further down when an unknown value is passed.
        raise ValueError("args.trajectories must be 'all' or 'human'")

    train_loader = DataLoader(train_data, args.batch_sz, collate_fn=get_collate_fn(args.cuda))
    valid_loader = DataLoader(valid_data, args.batch_sz, collate_fn=get_collate_fn(args.cuda))

    test_loader = DataLoader(test_data, args.batch_sz, collate_fn=get_collate_fn(args.cuda))

    # Load a pretrained tourist; load a pretrained guide if one is given,
    # otherwise initialize a fresh guide model.
    tourist = TouristLanguage.load(args.tourist_model)
    if args.guide_model is not None:
        guide = GuideLanguage.load(args.guide_model)
    else:
        guide = GuideLanguage(128, 256, len(train_data.dict), apply_masc=True, T=3)

    if args.cuda:
        tourist = tourist.cuda()
        guide = guide.cuda()

    if args.train_guide:
Code example #5
                        default=50,
                        help='Number of epochs')

    args = parser.parse_args()

    exp_dir = os.path.join(args.exp_dir, args.exp_name)
    if not os.path.exists(exp_dir):
        os.mkdir(exp_dir)

    logger = create_logger(os.path.join(exp_dir, 'log.txt'))
    logger.info(args)

    train_data = TalkTheWalkLanguage(args.data_dir, 'train', args.last_turns)
    train_loader = DataLoader(train_data,
                              args.batch_sz,
                              collate_fn=get_collate_fn(args.cuda),
                              shuffle=True)

    valid_data = TalkTheWalkLanguage(args.data_dir, 'valid', args.last_turns)
    valid_loader = DataLoader(valid_data,
                              args.batch_sz,
                              collate_fn=get_collate_fn(args.cuda))

    test_data = TalkTheWalkLanguage(args.data_dir, 'test', args.last_turns)
    test_loader = DataLoader(test_data,
                             args.batch_sz,
                             collate_fn=get_collate_fn(args.cuda))

    guide = GuideLanguage(args.embed_sz,
                          args.hidden_sz,
                          len(train_data.dict),