def show_samples(self, dataset, num_samples=10, cuda=True, logger=None,
                 decoding_strategy='sample', indices=None, beam_width=4):
    """Decode and print a handful of dataset examples for qualitative inspection.

    Runs the model forward (eval mode) on `num_samples` examples — randomly
    chosen unless `indices` is given — and prints, per example: the landmark
    observations, the action sequence, the ground-truth utterance, and the
    generated utterance.

    Args:
        dataset: dataset yielding (features, ...) tuples and exposing the
            decoding dictionaries `dict`, `act_dict`, and
            `map.landmark_dict`.  # assumes TalkTheWalk-style dataset -- TODO confirm
        num_samples: how many random examples to show when `indices` is None.
        cuda: whether the collate function should move batches to GPU.
        logger: optional logging callable; falls back to `print`.
        decoding_strategy: forwarded to `self.forward` (e.g. 'sample', 'beam').
        indices: explicit dataset indices to show; overrides random sampling.
        beam_width: beam size forwarded to `self.forward`.
    """
    if indices is None:
        # Sample with replacement; occasional duplicates are fine for eyeballing.
        indices = [random.randint(0, len(dataset) - 1)
                   for _ in range(num_samples)]

    collate_fn = get_collate_fn(cuda)
    data = [dataset[ind] for ind in indices]
    batch = collate_fn(data)

    out = self.forward(batch[0],
                       decoding_strategy=decoding_strategy,
                       train=False,
                       beam_width=beam_width)
    generated_utterance = out['utterance'].cpu().data

    logger_fn = logger if logger else print

    for i, example in enumerate(data):
        # One '(lm1,lm2,...) ,' chunk per observation step; joined in one pass
        # instead of quadratic string '+='.
        obs_str = ''.join(
            '(' + ','.join(dataset.map.landmark_dict.decode(o_ind)
                           for o_ind in obs) + ') ,'
            for obs in example[0]['goldstandard'])
        act_str = ','.join(dataset.act_dict.decode(a_ind)
                           for a_ind in example[0]['actions'])

        logger_fn('Observations: ' + obs_str)
        logger_fn('Actions: ' + act_str)
        # [i, 1:] skips the first token of the GT utterance
        # (presumably a start symbol -- TODO confirm against dataset encoding).
        logger_fn('GT: ' + dataset.dict.decode(batch[0]['utterance'][i, 1:]))
        logger_fn('Sample: ' + dataset.dict.decode(generated_utterance[i, :]))
        logger_fn('-' * 80)
parser.add_argument('--report-every', type=int, default=5)
parser.add_argument('--num-epochs', type=int, default=500,
                    help='Number of epochs')
args = parser.parse_args()

# Create the experiment directory; makedirs(exist_ok=True) is race-free and
# also creates missing parents (the original exists()+mkdir was racy).
exp_dir = os.path.join(args.exp_dir, args.exp_name)
os.makedirs(exp_dir, exist_ok=True)

# NOTE: the original called parser.parse_args() a second time here; the
# duplicate call was redundant and has been removed.
logger = create_logger(os.path.join(exp_dir, 'log.txt'))
logger.info(args)

# Emergent-communication splits with gold-standard landmark features;
# T is the trajectory length.
train_data = TalkTheWalkEmergent(args.data_dir, 'train',
                                 goldstandard_features=True, T=args.T)
train_loader = DataLoader(train_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda),
                          shuffle=True)
valid_data = TalkTheWalkEmergent(args.data_dir, 'valid',
                                 goldstandard_features=True, T=args.T)
valid_loader = DataLoader(valid_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda))
test_data = TalkTheWalkEmergent(args.data_dir, 'test',
                                goldstandard_features=True, T=args.T)
test_loader = DataLoader(test_data, args.batch_sz,
                         collate_fn=get_collate_fn(args.cuda))

# Continuous-communication agents sharing the landmark vocabulary.
guide = GuideContinuous(args.vocab_sz, len(train_data.map.landmark_dict),
                        apply_masc=args.apply_masc, T=args.T)
tourist = TouristContinuous(args.vocab_sz, len(train_data.map.landmark_dict),
                            len(train_data.act_dict),
                            apply_masc=args.apply_masc, T=args.T)

# Single optimizer over both agents' parameters (joint training).
params = list(tourist.parameters()) + list(guide.parameters())
opt = optim.Adam(params)
logger = create_logger(os.path.join(exp_dir, 'log.txt'))
logger.info(args)

# Load landmark-classification data (optionally with ResNet / fastText /
# text-recognition features) and split it into train/valid.
data = TalkTheWalkLandmarks(args.data_dir, args.resnet_features,
                            args.fasttext_features, args.textrecog_features)
train_data, valid_data = create_split(data)
# NOTE(review): presumably attaches per-class weights to the splits -- confirm
# against the definition of add_weights.
add_weights(train_data, valid_data)
train_data = DatasetHolder(train_data)
valid_data = DatasetHolder(valid_data)
train_loader = DataLoader(train_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda), shuffle=True)
valid_loader = DataLoader(valid_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda))

# Trivial baseline predictions over the validation targets:
# `ones` predicts every label positive, `rand` predicts 0/1 uniformly.
# (`rand` is unused in this excerpt; presumably scored further below.)
target = numpy.array(
    [valid_data[i]['target'] for i in range(len(valid_data))])
ones = numpy.ones_like(target)
rand = numpy.random.randint(2, size=target.shape)

# Report weighted F1 / precision / recall for the all-positive baseline.
logger.info('Baselines' + '-' * 70)
logger.info("All positive: {}, {}, {}".format(
    f1_score(target, ones, average='weighted'),
    precision_score(target, ones, average='weighted'),
    recall_score(target, ones, average='weighted')))
data_dir = args.data_dir

# Select the trajectory source: 'all' uses the emergent-communication splits
# with one shared dictionary (min word frequency 3); 'human' uses the
# natural-language splits, which carry their own dictionaries.
if args.trajectories == 'all':
    dictionary = Dictionary(file=os.path.join(data_dir, 'dict.txt'), min_freq=3)
    train_data = TalkTheWalkEmergent(data_dir, 'train', T=args.T)
    # Overwrite each split's dictionary so all splits share one vocabulary.
    train_data.dict = dictionary
    valid_data = TalkTheWalkEmergent(data_dir, 'valid', T=args.T)
    valid_data.dict = dictionary
    test_data = TalkTheWalkEmergent(data_dir, 'test', T=args.T)
    test_data.dict = dictionary
elif args.trajectories == 'human':
    train_data = TalkTheWalkLanguage(data_dir, 'train')
    valid_data = TalkTheWalkLanguage(data_dir, 'valid')
    test_data = TalkTheWalkLanguage(data_dir, 'test')

train_loader = DataLoader(train_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda))
valid_loader = DataLoader(valid_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda))
test_loader = DataLoader(test_data, args.batch_sz,
                         collate_fn=get_collate_fn(args.cuda))

# The tourist is always loaded from a checkpoint; the guide is either loaded
# or freshly initialized with fixed hyperparameters (128 embed, 256 hidden).
tourist = TouristLanguage.load(args.tourist_model)
if args.guide_model is not None:
    guide = GuideLanguage.load(args.guide_model)
else:
    guide = GuideLanguage(128, 256, len(train_data.dict),
                          apply_masc=True, T=3)

if args.cuda:
    tourist = tourist.cuda()
    guide = guide.cuda()

# Branch body continues beyond this excerpt.
if args.train_guide:
default=50, help='Number of epochs')  # tail of an add_argument call started above this excerpt
args = parser.parse_args()

# Create the experiment directory if it does not exist yet.
exp_dir = os.path.join(args.exp_dir, args.exp_name)
if not os.path.exists(exp_dir):
    os.mkdir(exp_dir)

logger = create_logger(os.path.join(exp_dir, 'log.txt'))
logger.info(args)

# Natural-language Talk The Walk splits; `last_turns` presumably limits how
# much dialogue history each example keeps -- confirm against the dataset.
train_data = TalkTheWalkLanguage(args.data_dir, 'train', args.last_turns)
train_loader = DataLoader(train_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda), shuffle=True)
valid_data = TalkTheWalkLanguage(args.data_dir, 'valid', args.last_turns)
valid_loader = DataLoader(valid_data, args.batch_sz,
                          collate_fn=get_collate_fn(args.cuda))
test_data = TalkTheWalkLanguage(args.data_dir, 'test', args.last_turns)
test_loader = DataLoader(test_data, args.batch_sz,
                         collate_fn=get_collate_fn(args.cuda))

# Guide model construction; the call's remaining arguments continue beyond
# this excerpt.
guide = GuideLanguage(args.embed_sz, args.hidden_sz, len(train_data.dict),