def get_datasets(config_data):
    x_transform_train, x_transform_val, y_transform, y_transform_val = get_transforms(
        config_data)
    data_dir = config_data['dataset']['path']
    train_fraction = config_data['dataset']['training_fraction']
    timeseries = config_data['dataset']['timeseries']
    timesteps = config_data['model']['fgru_timesteps']

    # Build Processed Dataset if it doesn't exist
    make_acdc_dataset(data_dir, timeseries, timesteps)
    x_train_file, y_train_file, x_val_file, y_val_file = build_dataset_files(
        data_dir, train_fraction)

    ds_train = SimpleDataset(x_train_file,
                             y_train_file,
                             x_transform=x_transform_train,
                             y_transform=y_transform,
                             use_cache=True)
    ds_val = SimpleDataset(x_val_file,
                           y_val_file,
                           x_transform=x_transform_val,
                           y_transform=y_transform_val,
                           use_cache=True)

    return ds_train, ds_val
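A minimal usage sketch (not part of the original snippet): the two returned datasets are ordinary map-style datasets, so they can be wrapped in torch.utils.data.DataLoader as usual. The config_data dict and the batch size below are placeholders.

# Hypothetical usage of get_datasets; config_data is assumed to be a parsed config dict.
import torch

ds_train, ds_val = get_datasets(config_data)
train_loader = torch.utils.data.DataLoader(ds_train, batch_size=8, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(ds_val, batch_size=8, shuffle=False, num_workers=4)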
Example #2
    def make_sample_dataloader(self, day_inputs, day_gov_inputs, gbm_outputs, outputs, shuffle=False):
        if self.config.use_saintdataset:
            dataset = SAINTDataset(
                [day_inputs, day_gov_inputs, gbm_outputs, outputs],
                self.edge_index, self.edge_weight, self.config.num_nodes,
                self.config.batch_size, shuffle=shuffle,
                shuffle_order=self.config.saint_shuffle_order,
                saint_sample_type=self.config.saint_sample_type,
                saint_batch_size=self.config.saint_batch_size,
                saint_walk_length=self.config.saint_walk_length,
            )

            return DataLoader(dataset, batch_size=None)
        else:
            dataset = SimpleDataset([day_inputs, day_gov_inputs, gbm_outputs, outputs])
            def collate_fn(samples):
                day_inputs = torch.cat([item[0][0] for item in samples]).unsqueeze(0)   # [1,bs,seq_length,feature_dim]
                day_gov_inputs = torch.cat([item[0][1] for item in samples]).unsqueeze(0)   # [1,bs,seq_length,feature_dim]
                gbm_outputs = torch.cat([item[0][-2] for item in samples]).unsqueeze(0)
                outputs = torch.cat([item[0][-1] for item in samples]).unsqueeze(0)
                node_ids = torch.LongTensor([item[1] for item in samples])   # [bs]
                date_ids = torch.LongTensor([item[2] for item in samples])   # [bs]
                return [[day_inputs, day_gov_inputs, gbm_outputs, outputs], {'cent_n_id':node_ids,'type':'random'}, date_ids]

            return DataLoader(dataset, batch_size=self.config.batch_size, shuffle=shuffle, collate_fn=collate_fn)
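The custom collate_fn above indexes each sample as (tensor_list, node_id, date_id), with every tensor shaped [1, seq_length, feature_dim] so they can be concatenated along the batch dimension. A minimal sketch of a dataset with that interface, inferred from the collate function rather than taken from the project's actual SimpleDataset:

import torch
from torch.utils.data import Dataset

class NodeDateDataset(Dataset):
    """Hypothetical stand-in with the sample layout the collate_fn above expects."""

    def __init__(self, tensor_lists, num_nodes):
        # tensor_lists: [day_inputs, day_gov_inputs, gbm_outputs, outputs],
        # each of shape [num_samples, seq_length, feature_dim]
        self.tensor_lists = tensor_lists
        self.num_nodes = num_nodes

    def __len__(self):
        return self.tensor_lists[0].shape[0]

    def __getitem__(self, idx):
        # Keep a leading dim of 1 so torch.cat in collate_fn builds [bs, seq, feat].
        tensors = [t[idx:idx + 1] for t in self.tensor_lists]
        node_id = idx % self.num_nodes   # assumed sample-to-node mapping
        date_id = idx // self.num_nodes  # assumed sample-to-date mapping
        return tensors, node_id, date_id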
Example #3
def setup_train_cross_dataset(splits, epoch, args):
    test_th = epoch % len(splits)
    train_pairs = []
    cross_pairs = splits[test_th]
    for i, split in enumerate(splits):
        if i == test_th:
            continue
        train_pairs += split

    if len(train_pairs) > args.train_amount:
        train_pairs = random.sample(train_pairs, args.train_amount)

    if len(cross_pairs) > args.cross_val_amount:
        cross_pairs = random.sample(cross_pairs, args.cross_val_amount)

    return (SimpleDataset(train_pairs, args,
                          True), SimpleDataset(cross_pairs, args, False))
Example #4
def get_test_dataset(config_data):
    x_transform = get_test_transforms(config_data)
    data_dir = config_data['dataset']['path']
    timeseries = config_data['dataset']['timeseries']
    timesteps = config_data['model']['fgru_timesteps']

    # Build Processed Dataset if it doesn't exist
    test_file = make_acdc_test_dataset(data_dir, timeseries, timesteps)
    ds_test = SimpleDataset(test_file, x_transform=x_transform, use_cache=True)
    return ds_test
Example #5
    def get_data_loader(self,
                        mode,
                        batch_size,
                        aug,
                        num_workers=8,
                        lazy_load=False):
        transform = self.trans_loader.get_composed_transform(aug)
        indexes = self.train_indexes if mode == 'train' else self.test_indexes
        shuffle = mode == 'train'  # no shuffle at test time
        dataset = SimpleDataset(self.data_file,
                                transform,
                                indexes=indexes,
                                lazy_load=lazy_load)
        data_loader_params = dict(batch_size=batch_size,
                                  shuffle=shuffle,
                                  num_workers=num_workers,
                                  pin_memory=True)
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  **data_loader_params)
        return data_loader
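A short usage sketch; data_mgr stands in for whatever class defines get_data_loader above, and the argument values are placeholders.

train_loader = data_mgr.get_data_loader(mode='train', batch_size=64, aug=True)
test_loader = data_mgr.get_data_loader(mode='test', batch_size=64, aug=False, num_workers=4)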
Example #6
def topic_words_examine(num_of_topics):
    dataset = SimpleDataset(validation_percentage=validation_percentage,
                            dataset_name=dataset_name)
    loader = DataLoader(data='valid',
                        simple_dataset=dataset,
                        dataset_name='sem-2016',
                        padding=False)
    # for idx, sentence in enumerate(dataset.valid_original_sentence):
    #     print(idx)
    #     print(sentence)
    #     print(dataset.valid_data[idx][0])
    #     print(loader[idx][1])
    # f, ax = plt.subplots(figsize=(9, 6))
    # flights = flights_long.pivot("month", "year", "passengers")
    # print(type(flights))
    # exit()
    # topic_words = [[] for i in range(num_of_topics)]
    # item = loader[sentence_idx][0]
    # preprocessed_sentence = dataset.valid_data[sentence_idx][0]
    # orig_sentence = dataset.valid_original_sentence[sentence_idx]
    # translator = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
    # orig_sentence = orig_sentence.translate(translator)
    # orig_sentence = orig_sentence.split()
    # weights, existence = get_sentence_weights('./topic-attention', item)

    topics = []
    for sentence_idx, data in enumerate(loader):
        item = loader[sentence_idx][0]
        preprocessed_sentence = dataset.valid_data[sentence_idx][0]
        orig_sentence = dataset.valid_original_sentence[sentence_idx]
        weights, existence = get_sentence_weights('./topic-attention', item)
        stopwords_english = set(stopwords.words('english'))
        attention_weights = []
        words = []
        weights = [
            list(weight.squeeze(0).squeeze(-1).detach().numpy())
            for weight in weights
        ]
        most_important = [np.argmax(weight) for weight in weights]
        most_important = [preprocessed_sentence[i] for i in most_important]
        print(most_important)
Example #7
tokenizer = label_encoder._first_module().tokenizer

instance_encoder = label_encoder

model = DualEncoderModel(
    label_encoder,
    instance_encoder,
)
model = model.to(device)

# the whole label set
data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset', args.dataset)
all_labels = pd.read_json(os.path.join(data_path, 'lbl.json'), lines=True)
label_list = list(all_labels.title)
label_ids = list(all_labels.uid)
label_data = SimpleDataset(label_list, transform=tokenizer.encode)

# label dataloader for searching
sampler = SequentialSampler(label_data)
label_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 64)
label_dataloader = DataLoader(label_data,
                              sampler=sampler,
                              batch_size=16,
                              collate_fn=label_padding_func)
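
# padding_util is not defined in this excerpt. Below is a minimal sketch of such a
# collate helper, assuming it right-pads the encoded token-id lists to a fixed length
# and returns input ids plus an attention mask; the project's actual helper may differ.
def padding_util_sketch(batch, pad_token_id, max_len):
    input_ids = torch.full((len(batch), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(batch), max_len), dtype=torch.long)
    for i, ids in enumerate(batch):
        ids = ids[:max_len]
        input_ids[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)
        attention_mask[i, :len(ids)] = 1
    return input_ids, attention_mask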

# test data
data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset', args.dataset)
try:
    accelerator.print("load cache")
    all_instances = torch.load(
        os.path.join(data_path, 'all_passages_with_titles.json.cache.pt'))
Example #8
# # Limit for Testing
# train_files = train_files[:129]
# val_files = val_files[:129]
# test_files = test_files[:129]

from dataset import SimpleDataset
# train_sd = SimpleDataset(os.path.join(args.data_path,'input/birdclef-2021/train_short_audio'),
#                          name='train',
#                          batch_size=BATCH_SIZE,
#                          is_test=True,
#                          files_list=train_files)
# train_ds = train_sd.get_dataset()

train_sd = SimpleDataset(os.path.join(args.data_path,
                                      'input/birdclef-2021/train_short_audio'),
                         name='train',
                         batch_size=BATCH_SIZE,
                         files_list=train_files)
train_ds = train_sd.get_dataset()
if args.online:
    run.tag('Training', 'Dataset')

val_sd = SimpleDataset(os.path.join(args.data_path,
                                    'input/birdclef-2021/train_short_audio'),
                       name='validation',
                       batch_size=BATCH_SIZE,
                       is_test=True,
                       files_list=val_files,
                       sr=sr)
val_ds = val_sd.get_dataset()
if args.online:
Example #9
def generateResults(
    query_path,
    gallery_path,
    model=None,
    modelpath="/home/ankit/csce-625-person-re-identification/Siamese/trained_resnets/checkpoint_33.tar"
):
    if model:
        model.cuda()
    else:
        checkpoint = torch.load(model_path)
        model = Siamese()
        model.cuda()
        model.load(checkpoint['state_dict'])

    imgTransforms = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    query_list = sorted([filename
                         for _, _, filename in os.walk(query_path)][0])
    count = 1
    total = len(query_list)
    CMC = 0
    ap = 0

    raw_dists = []  #temporary

    for img in query_list:
        dists = []  #temporary

        if count % 10 == 1:
            print(f"starting {count}/{total}")

        query_tensor = imgTransforms(
            Image.open(os.path.join(query_path, img)).convert('RGB'))
        query_tensor = torch.autograd.Variable(query_tensor,
                                               volatile=True).cuda()
        query_tensor = query_tensor.view(1, query_tensor.size(0),
                                         query_tensor.size(1),
                                         query_tensor.size(2))

        gallery_loader = torch.utils.data.DataLoader(SimpleDataset(
            path=gallery_path, transforms=imgTransforms),
                                                     batch_size=1,
                                                     shuffle=False,
                                                     num_workers=4,
                                                     pin_memory=True)

        ### TEMPORARY SOLUTION TO THE GALLERY IMAGE PROBLEM ###
        if count == 1:
            raw_dists = getDists(query_tensor, gallery_loader, model)

        # compute output
        model.eval()
        query_feature, _ = model(query_tensor, query_tensor)

        for dist in raw_dists:
            dists.append((F.mse_loss(dist[0], query_feature), dist[1]))

        #######################################################

        dists.sort(key=lambda val: val[0])

        bestImgNames = [el[1][0] for el in dists]

        #print(bestImgNames)

        ap_tmp, CMC_tmp = createStats(img, bestImgNames)
        if CMC_tmp[0] == -1:
            continue

        CMC += CMC_tmp
        ap += ap_tmp
        count += 1
    CMC = CMC.float()
    CMC /= len(query_list)
    ap /= len(query_list)

    #print('top1: %.4f, top5: %.4f, top10: %.4f, mAP: %.4f' % (CMC[0], CMC[4], CMC[9], ap))
    return (CMC[0], CMC[4], CMC[9], ap)
Example #10
def sentence_weight_examine(sentence_idx):
    dataset = SimpleDataset(validation_percentage=validation_percentage,
                            dataset_name=dataset_name)
    loader = DataLoader(data='valid',
                        simple_dataset=dataset,
                        dataset_name='sem-2016',
                        padding=False)
    # for idx, sentence in enumerate(dataset.valid_original_sentence):
    #     print(idx)
    #     print(sentence)
    #     print(dataset.valid_data[idx][0])
    #     print(loader[idx][1])
    # f, ax = plt.subplots(figsize=(9, 6))
    # flights = flights_long.pivot("month", "year", "passengers")
    # print(type(flights))
    # exit()
    item = loader[sentence_idx][0]
    preprocessed_sentence = dataset.valid_data[sentence_idx][0]
    orig_sentence = dataset.valid_original_sentence[sentence_idx]
    translator = str.maketrans(string.punctuation,
                               ' ' * len(string.punctuation))
    orig_sentence = orig_sentence.translate(translator)
    orig_sentence = orig_sentence.split()
    weights, existence = get_sentence_weights('./topic-attention', item)

    topics = []
    for i in range(11):
        topics.append(str(i + 1))
    stopwords_english = set(stopwords.words('english'))
    attention_weights = []
    words = []
    for idx, weight in enumerate(weights):
        weight = list(weight.squeeze(0).squeeze(-1).detach().numpy())
        temp = []
        for word in orig_sentence:
            words.append(word)
            if word.lower() in stopwords_english:
                temp.append(float(0.0))
            else:
                temp.append(
                    float(weight[preprocessed_sentence.index(word.lower())]))
        attention_weights.append(np.array(temp))
    attention_weights = np.array(attention_weights)
    attention_weights = attention_weights.transpose()

    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(attention_weights,
                annot=True,
                fmt=".2f",
                ax=ax,
                xticklabels=topics,
                yticklabels=orig_sentence,
                cmap="YlGnBu",
                cbar_kws={'label': 'Attention Weights'})
    plt.xlabel('Topics')
    plt.ylabel('Sentence')
    plt.savefig('./attention_heatmap/valid_sentence_#' + str(sentence_idx) +
                '_attention_weights')

    topic_probs = []
    for idx in range(len(existence)):
        prob = float(existence[idx])
        topic_probs.append(prob)
    topic_probs = [np.array(topic_probs)]
    topic_probs = np.array(topic_probs).transpose()
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(topic_probs,
                annot=True,
                fmt=".2f",
                ax=ax,
                xticklabels=['Topic Probabilities'],
                yticklabels=topics,
                cmap="YlGnBu")
    plt.savefig('./attention_heatmap/valid_sentence_#' + str(sentence_idx) +
                '_topic_probs')
Example #11
def main():
    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    args = parse_args()
    distributed_args = accelerate.DistributedDataParallelKwargs(
        find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[distributed_args])
    device = accelerator.device
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        filename=f'xmc_{args.dataset}_{args.mode}_{args.log}.log',
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(
        logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    ch = logging.StreamHandler(sys.stdout)
    logger.addHandler(ch)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    logger.info(sent_trans.__file__)

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Load pretrained model and tokenizer
    if args.model_name_or_path == 'bert-base-uncased' or args.model_name_or_path == 'sentence-transformers/paraphrase-mpnet-base-v2':
        query_encoder = build_encoder(
            args.model_name_or_path,
            args.max_label_length,
            args.pooling_mode,
            args.proj_emb_dim,
        )
    else:
        query_encoder = sent_trans.SentenceTransformer(args.model_name_or_path)

    tokenizer = query_encoder._first_module().tokenizer

    block_encoder = query_encoder

    model = DualEncoderModel(query_encoder, block_encoder, args.mode)
    model = model.to(device)

    # the whole label set
    data_path = os.path.join(os.path.abspath(os.getcwd()), 'dataset',
                             args.dataset)
    all_labels = pd.read_json(os.path.join(data_path, 'lbl.json'), lines=True)
    label_list = list(all_labels.title)
    label_ids = list(all_labels.uid)
    label_data = SimpleDataset(label_list, transform=tokenizer.encode)

    # label dataloader for searching
    sampler = SequentialSampler(label_data)
    label_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 64)
    label_dataloader = DataLoader(label_data,
                                  sampler=sampler,
                                  batch_size=16,
                                  collate_fn=label_padding_func)

    # label dataloader for regularization
    reg_sampler = RandomSampler(label_data)
    reg_dataloader = DataLoader(label_data,
                                sampler=reg_sampler,
                                batch_size=4,
                                collate_fn=label_padding_func)

    if args.mode == 'ict':
        train_data = ICTXMCDataset(tokenizer=tokenizer, dataset=args.dataset)
    elif args.mode == 'self-train':
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode)
    elif args.mode == 'finetune-pair':
        train_path = os.path.join(data_path, 'trn.json')
        pos_pair = []
        with open(train_path) as fp:
            for i, line in enumerate(fp):
                inst = json.loads(line.strip())
                inst_id = inst['uid']
                for ind in inst['target_ind']:
                    pos_pair.append((inst_id, ind, i))
        dataset_size = len(pos_pair)
        indices = list(range(dataset_size))
        split = int(np.floor(args.ratio * dataset_size))
        np.random.shuffle(indices)
        train_indices = indices[:split]
        torch.distributed.broadcast_object_list(train_indices,
                                                src=0,
                                                group=None)
        sample_pairs = [pos_pair[i] for i in train_indices]
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)
    elif args.mode == 'finetune-label':
        label_index = []
        label_path = os.path.join(data_path, 'label_index.json')
        with open(label_path) as fp:
            for line in fp:
                label_index.append(json.loads(line.strip()))
        np.random.shuffle(label_index)
        sample_size = int(np.floor(args.ratio * len(label_index)))
        sample_label = label_index[:sample_size]
        torch.distributed.broadcast_object_list(sample_label,
                                                src=0,
                                                group=None)
        sample_pairs = []
        for i, label in enumerate(sample_label):
            ind = label['ind']
            for inst_id in label['instance']:
                sample_pairs.append((inst_id, ind, i))
        train_data = PosDataset(tokenizer=tokenizer,
                                dataset=args.dataset,
                                labels=label_list,
                                mode=args.mode,
                                sample_pairs=sample_pairs)

    train_sampler = RandomSampler(train_data)
    padding_func = lambda x: ICT_batchify(x, tokenizer.pad_token_id, 64, 288)
    train_dataloader = torch.utils.data.DataLoader(
        train_data,
        sampler=train_sampler,
        batch_size=args.per_device_train_batch_size,
        num_workers=4,
        pin_memory=False,
        collate_fn=padding_func)

    try:
        accelerator.print("load cache")
        all_instances = torch.load(
            os.path.join(data_path, 'all_passages_with_titles.json.cache.pt'))
        test_data = SimpleDataset(all_instances.values())
    except:
        all_instances = {}
        test_path = os.path.join(data_path, 'tst.json')
        if args.mode == 'ict':
            train_path = os.path.join(data_path, 'trn.json')
            train_instances = {}
            valid_passage_ids = train_data.valid_passage_ids
            with open(train_path) as fp:
                for line in fp:
                    inst = json.loads(line.strip())
                    train_instances[
                        inst['uid']] = inst['title'] + '\t' + inst['content']
            for inst_id in valid_passage_ids:
                all_instances[inst_id] = train_instances[inst_id]
        test_ids = []
        with open(test_path) as fp:
            for line in fp:
                inst = json.loads(line.strip())
                all_instances[
                    inst['uid']] = inst['title'] + '\t' + inst['content']
                test_ids.append(inst['uid'])
        simple_transform = lambda x: tokenizer.encode(
            x, max_length=288, truncation=True)
        test_data = SimpleDataset(list(all_instances.values()),
                                  transform=simple_transform)
    inst_num = len(test_data)

    sampler = SequentialSampler(test_data)
    sent_padding_func = lambda x: padding_util(x, tokenizer.pad_token_id, 288)
    instance_dataloader = DataLoader(test_data,
                                     sampler=sampler,
                                     batch_size=128,
                                     collate_fn=sent_padding_func)

    # prepare pairs
    reader = csv.reader(open(os.path.join(data_path, 'all_pairs.txt'),
                             encoding="utf-8"),
                        delimiter=" ")
    qrels = {}
    for id, row in enumerate(reader):
        query_id, corpus_id, score = row[0], row[1], int(row[2])
        if query_id not in qrels:
            qrels[query_id] = {corpus_id: score}
        else:
            qrels[query_id][corpus_id] = score

    logging.info("| |ICT_dataset|={} pairs.".format(len(train_data)))

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=1e-8)

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, label_dataloader, reg_dataloader, instance_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, label_dataloader, reg_dataloader,
        instance_dataloader)

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    # args.max_train_steps = 100000
    args.num_train_epochs = math.ceil(args.max_train_steps /
                                      num_update_steps_per_epoch)
    args.num_warmup_steps = int(0.1 * args.max_train_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_data)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Learning Rate = {args.learning_rate}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps),
                        disable=not accelerator.is_local_main_process)
    completed_steps = 0
    from torch.cuda.amp import autocast
    scaler = torch.cuda.amp.GradScaler()
    cluster_result = eval_and_cluster(args, logger, completed_steps,
                                      accelerator.unwrap_model(model),
                                      label_dataloader, label_ids,
                                      instance_dataloader, inst_num, test_ids,
                                      qrels, accelerator)
    reg_iter = iter(reg_dataloader)
    trial_name = f"dim-{args.proj_emb_dim}-bs-{args.per_device_train_batch_size}-{args.dataset}-{args.log}-{args.mode}"
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t for t in batch)
            label_tokens, inst_tokens, indices = batch
            if args.mode == 'ict':
                try:
                    reg_data = next(reg_iter)
                except StopIteration:
                    reg_iter = iter(reg_dataloader)
                    reg_data = next(reg_iter)

            if cluster_result is not None:
                pseudo_labels = cluster_result[indices]
            else:
                pseudo_labels = indices
            with autocast():
                if args.mode == 'ict':
                    label_emb, inst_emb, inst_emb_aug, reg_emb = model(
                        label_tokens, inst_tokens, reg_data)
                    loss, stats_dict = loss_function_reg(
                        label_emb, inst_emb, inst_emb_aug, reg_emb,
                        pseudo_labels, accelerator)
                else:
                    label_emb, inst_emb = model(label_tokens,
                                                inst_tokens,
                                                reg_data=None)
                    loss, stats_dict = loss_function(label_emb, inst_emb,
                                                     pseudo_labels,
                                                     accelerator)
                loss = loss / args.gradient_accumulation_steps

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            if step % args.gradient_accumulation_steps == 0 or step == len(
                    train_dataloader) - 1:
                scaler.step(optimizer)
                scaler.update()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

            if completed_steps % args.logging_steps == 0:
                if args.mode == 'ict':
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e}  Contrast Loss {:.6e}  Reg Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                            stats_dict["contrast_loss"].item(),
                            stats_dict["reg_loss"].item(),
                        ))
                else:
                    logger.info(
                        "| Epoch [{:4d}/{:4d}] Step [{:8d}/{:8d}] Total Loss {:.6e}"
                        .format(
                            epoch,
                            args.num_train_epochs,
                            completed_steps,
                            args.max_train_steps,
                            stats_dict["loss"].item(),
                        ))
            if completed_steps % args.eval_steps == 0:
                cluster_result = eval_and_cluster(
                    args, logger, completed_steps,
                    accelerator.unwrap_model(model), label_dataloader,
                    label_ids, instance_dataloader, inst_num, test_ids, qrels,
                    accelerator)
                unwrapped_model = accelerator.unwrap_model(model)

                unwrapped_model.label_encoder.save(
                    f"{args.output_dir}/{trial_name}/label_encoder")
                unwrapped_model.instance_encoder.save(
                    f"{args.output_dir}/{trial_name}/instance_encoder")

            if completed_steps >= args.max_train_steps:
                break
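
The training loop above combines torch.cuda.amp autocast/GradScaler with gradient accumulation and clipping. For reference, a generic, minimal sketch of that pattern with placeholder model, loader, and optimizer (not this project's objects):

import torch
import torch.nn.functional as F

def train_amp_accum(model, loader, optimizer, accum_steps=4, device='cuda'):
    # Generic mixed-precision training loop with gradient accumulation.
    scaler = torch.cuda.amp.GradScaler()
    model.train()
    for step, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.cuda.amp.autocast():
            loss = F.cross_entropy(model(inputs), targets) / accum_steps
        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:
            scaler.unscale_(optimizer)  # unscale once, right before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()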
Example #12
def loi_test(args):
    metrics = utils.AverageContainer()

    if args.model == 'csrnet':
        args.loss_focus = 'cc'

    if args.pre == '':
        args.pre = 'weights/{}/last_model.pt'.format(args.save_dir)
    model = load_model(args)
    model.eval()

    # Get a pretrained fixed model for the flow prediction
    if args.loss_focus == 'cc':
        fe_model = P21Small(load_pretrained=True).cuda()
        if args.dataset == 'fudan':
            pre_fe = '20201123_122014_dataset-fudan_model-p21small_density_model-fixed-8_cc_weight-50_frames_between-5_epochs-400_lr_setting-adam_9'
        elif args.dataset == 'ucsd':
            pre_fe = '20201013_193544_dataset-ucsd_model-v332dilation_cc_weight-50_frames_between-2_epochs-750_loss_focus-fe_lr_setting-adam_2_resize_mode-bilinear'
        elif args.dataset == 'tub':
            pre_fe = '20201125_152055_dataset-tub_model-p21small_density_model-fixed-5_cc_weight-50_frames_between-5_epochs-350_lr_setting-adam_9'
        elif args.dataset == 'aicity':
            pre_fe = '20201126_192730_dataset-aicity_model-p21small_density_model-fixed-5_cc_weight-50_frames_between-5_epochs-350_lr_setting-adam_9'
        else:
            print("This dataset doesnt have flow only results")
            exit()

        fe_model.load_state_dict(
            torch.load('weights/{}/last_model.pt'.format(pre_fe)))
        fe_model.eval()

    results = []

    ucsd_total_count = [[], []]

    with torch.no_grad():
        # Right now on cross validation
        _, test_vids = load_test_dataset(args)
        for v_i, video in enumerate(test_vids):

            vid_result = []
            print("Video ({}): {}".format(v_i, video.get_path()))

            if args.eval_method == 'roi':
                skip_inbetween = False
            else:
                skip_inbetween = True

            video.generate_frame_pairs(distance=args.frames_between,
                                       skip_inbetween=skip_inbetween)
            dataset = SimpleDataset(video.get_frame_pairs(), args, False)
            dataloader = torch.utils.data.DataLoader(
                dataset, batch_size=1, num_workers=args.dataloader_workers)

            # ---- REMOVE WHEN DONE ---- #
            if args.loi_flow_width:
                print("Loading flow")
                frames1, frames2, _ = next(iter(dataloader))
                frames1 = frames1.cuda()
                frames2 = frames2.cuda()
                fe_output, _, _ = fe_model.forward(frames1, frames2)
                fe_speed = torch.norm(fe_output, dim=1)

                threshold = 1.0
                fe_speed_high2 = fe_speed > threshold
                avg_speed = fe_speed[fe_speed_high2].sum(
                ) / fe_speed_high2.sum()

                if args.loi_flow_smoothing:
                    loi_width = int(
                        (0.75 + avg_speed / 10 * 0.25) * args.loi_width)
                else:
                    loi_width = int((avg_speed / 9) * args.loi_width)
            else:
                loi_width = args.loi_width

            # Sometimes a single line is required for all the information
            if args.loi_level == 'moving_counting' or args.loi_level == 'take_image' or args.eval_method == 'roi' or len(
                    video.get_lines()) == 0:
                line = basic_entities.BasicLineSample(video, (0, 0),
                                                      (100, 150))
                line.set_crossed(1, 1)
                lines = [line]
                real_line = False
            else:
                lines = video.get_lines()
                real_line = True

            for l_i, line in enumerate(lines):
                # Circumvent some bugs: when no pedestrians cross the line at all, skip it
                if line.get_crossed()[0] + line.get_crossed()[1] == 0:
                    continue

                # Setup LOI
                image = video.get_frame(0).get_image()
                width, height = image.size
                image.close()
                point1, point2 = line.get_line()
                loi_model = loi.LOI_Calculator(point1,
                                               point2,
                                               img_width=width,
                                               img_height=height,
                                               crop_processing=True,
                                               loi_version=args.loi_version,
                                               loi_width=loi_width,
                                               loi_height=loi_width *
                                               args.loi_height)

                if args.loi_level != 'take_image':
                    loi_model.create_regions()
                else:
                    # Index of capture in the video
                    capt = 0

                if args.dataset == 'aicity':
                    roi = aicity.create_roi(video)

                per_frame = [[], []]
                crosses = line.get_crossed()

                pbar = tqdm(total=len(video.get_frame_pairs()))

                metrics['timing'].reset()

                for s_i, batch in enumerate(dataloader):

                    # Take 6 screens per video
                    if args.loi_level == 'take_image':
                        n_screens = len(video.get_frame_pairs())
                        every_screen = math.ceil(n_screens / 6)
                        if ((s_i + 1) % every_screen) != 0:
                            continue

                    torch.cuda.empty_cache()
                    timer = utils.sTimer("Full process time")

                    frame_pair = video.get_frame_pairs()[s_i]
                    print_i = '{:05d}'.format(s_i + 1)
                    frames1, frames2, densities, densities2 = batch
                    frames1 = loi_model.reshape_image(frames1.cuda())
                    frames2 = loi_model.reshape_image(frames2.cuda())

                    # Expand for CC only (CSRNet to optimize for real world applications)
                    if args.loss_focus == 'cc':
                        if args.model == 'csrnet':
                            cc_output = model(frames1)
                        else:
                            _, _, cc_output = model.forward(frames1, frames2)

                        fe_output, _, _ = fe_model.forward(frames1, frames2)
                    else:
                        fe_output, _, cc_output = model.forward(
                            frames1, frames2)

                    # Apply maxing!!!
                    if args.loi_maxing == 1:
                        if args.dataset == 'fudan':
                            fe_output = get_max_surrounding(fe_output,
                                                            surrounding=6,
                                                            only_under=True,
                                                            smaller_sides=True)
                            fe_output = get_max_surrounding(fe_output,
                                                            surrounding=6,
                                                            only_under=True,
                                                            smaller_sides=True)
                            fe_output = get_max_surrounding(
                                fe_output,
                                surrounding=4,
                                only_under=True,
                                smaller_sides=False)
                        elif args.dataset == 'tub':
                            fe_output = get_max_surrounding(fe_output,
                                                            surrounding=6,
                                                            only_under=True,
                                                            smaller_sides=True)
                        elif args.dataset == 'ucsd':
                            fe_output = get_max_surrounding(fe_output,
                                                            surrounding=6,
                                                            only_under=True,
                                                            smaller_sides=True)
                        elif args.dataset == 'aicity':
                            fe_output = get_max_surrounding(
                                fe_output,
                                surrounding=6,
                                only_under=False,
                                smaller_sides=False)
                            fe_output = get_max_surrounding(
                                fe_output,
                                surrounding=6,
                                only_under=False,
                                smaller_sides=False)
                        else:
                            print("No maxing exists for this dataset")
                            exit()

                    # Resize and save as numpy
                    cc_output = loi_model.to_orig_size(cc_output)
                    cc_output = cc_output.squeeze().squeeze()
                    cc_output = cc_output.detach().cpu().data.numpy()
                    fe_output = loi_model.to_orig_size(fe_output)
                    fe_output = fe_output.squeeze().permute(1, 2, 0)
                    fe_output = fe_output.detach().cpu().data.numpy()

                    # If in take_image mode we only quickly take a picture of the results!!!
                    # Move to utilities
                    if args.loi_level == 'take_image':
                        dir1 = 'full_imgs/{}/{}-{}/'.format(
                            args.dataset, v_i, capt)
                        capt = capt + 1
                        back = '{}_m{}_{}_{}'.format(
                            args.model, args.loi_maxing, args.frames_between,
                            datetime.now().strftime("%Y%m%d_%H%M%S"))
                        Path(dir1).mkdir(parents=True, exist_ok=True)

                        img = Image.open(video.get_frame_pairs()
                                         [s_i].get_frames(1).get_image_path())
                        img.save('{}orig2.jpg'.format(dir1))

                        density = torch.FloatTensor(
                            density_filter.gaussian_filter_fixed_density(
                                video.get_frame_pairs()[s_i].get_frames(0), 8))
                        density = density.numpy()
                        cc_img = Image.fromarray(density * 255.0 /
                                                 density.max())
                        cc_img = cc_img.convert("L")
                        cc_img.save('{}orig_cc.jpg'.format(dir1))

                        img = Image.open(video.get_frame_pairs()
                                         [s_i].get_frames(0).get_image_path())
                        cc = cc_output
                        fe = fe_output

                        img.save('{}orig.jpg'.format(dir1))

                        cc_img = Image.fromarray(cc_output * 255.0 /
                                                 cc_output.max())
                        cc_img = cc_img.convert("L")
                        cc_img.save('{}cc_{}.jpg'.format(dir1, back))

                        fe_img = Image.fromarray(np.uint8(
                            utils.flo_to_color(fe)),
                                                 mode='RGB')
                        fe_img.save('{}fe_{}.jpg'.format(dir1, back))

                        cc_img = cc_img.convert('RGB')
                        blended = Image.blend(cc_img, fe_img, alpha=0.25)
                        blended.save('{}blend_{}.jpg'.format(dir1, back))
                        continue

                    # Apply ROI's before calculating LOI
                    if args.dataset == 'aicity':
                        cc_output = np.multiply(roi, cc_output)
                        densities = np.multiply(roi, densities)

                    # Extract LOI results
                    if args.loi_level == 'pixel':
                        loi_results = loi_model.pixelwise_forward(
                            cc_output, fe_output)
                    elif args.loi_level == 'region':
                        loi_results = loi_model.regionwise_forward(
                            cc_output, fe_output)
                    elif args.loi_level == 'crossed':
                        loi_results = loi_model.cross_pixelwise_forward(
                            cc_output, fe_output)
                    elif args.loi_level == 'moving_counting':
                        loi_results = ([0], [0])
                        if args.dataset == 'tub':
                            minimum_move = 3
                        elif args.dataset == 'aicity':
                            minimum_move = 6
                        else:
                            print(
                                "This dataset doesn't work with moving counting"
                            )
                            exit()

                        minimum_fe = np.linalg.norm(fe_output,
                                                    axis=2) > minimum_move
                        moving_density = np.multiply(minimum_fe, cc_output)

                        metrics['m_mae'].update(
                            abs(moving_density.sum() - len(
                                frame_pair.get_frames(0).get_centers(
                                    only_moving=True))))
                        metrics['m_mse'].update(
                            math.pow(
                                moving_density.sum() - len(
                                    frame_pair.get_frames(0).get_centers(
                                        only_moving=True)), 2))
                    else:
                        print('Incorrect LOI level')
                        exit()

                    # Keep the time and save it
                    if s_i > 0:
                        metrics['timing'].update(timer.show(False))

                    # ROI information should be saved once per video
                    if l_i == 0:
                        # Based on the density maps, but due to errors these can be off
                        metrics['old_roi_mae'].update(
                            abs((cc_output.sum() - densities.sum()).item()))
                        metrics['old_roi_mse'].update(
                            torch.pow(cc_output.sum() - densities.sum(),
                                      2).item())

                        # Comparing with the real counts is better for real-world applications, but often gives worse performance
                        metrics['real_roi_mae'].update(
                            abs(cc_output.sum().item() -
                                len(frame_pair.get_frames(0).get_centers())))
                        metrics['real_roi_mse'].update(
                            math.pow(
                                cc_output.sum().item() -
                                len(frame_pair.get_frames(0).get_centers()),
                                2))

                    # @TODO: Fix this so that all totals are accumulated like this
                    ucsd_total_count[0].append(sum(loi_results[0]))
                    ucsd_total_count[1].append(sum(loi_results[1]))
                    per_frame[0].append(sum(loi_results[0]))
                    per_frame[1].append(sum(loi_results[1]))

                    # Update GUI
                    pbar.set_description('{} ({}), {} ({})'.format(
                        sum(per_frame[0]), crosses[0], sum(per_frame[1]),
                        crosses[1]))
                    pbar.update(1)

                    # Another video saver, which one to use and which not??
                    if v_i == 0 and l_i == 0:
                        if s_i < 10:
                            img = Image.open(
                                video.get_frame_pairs()[s_i].get_frames(
                                    0).get_image_path())

                            utils.save_loi_sample(
                                "{}_{}_{}".format(v_i, l_i, s_i), img,
                                cc_output, fe_output)

                pbar.close()

                # A non-real (dummy) line is only used for ROI counting, not LOI counting, so stop here
                if not real_line:
                    break

                # The last frame is skipped because we can't predict it, so compensate here
                # @TODO fix the UCSD dataset and merge afterwards
                ucsd_total_count[0].append(0.0)
                ucsd_total_count[1].append(0.0)
                per_frame[0].append(0.0)
                per_frame[1].append(0.0)

                print("Timing {}".format(metrics['timing'].avg))

                # truth and predicted for evaluating all metrics
                t_left, t_right = crosses
                p_left, p_right = (sum(per_frame[0]), sum(per_frame[1]))

                mae = abs(t_left - p_left) + abs(t_right - p_right)
                metrics['loi_mae'].update(mae)

                mse = math.pow(t_left - p_left, 2) + math.pow(
                    t_right - p_right, 2)
                metrics['loi_mse'].update(mse)

                percentual_total_mae = (p_left + p_right) / (t_left + t_right)
                metrics['loi_ptmae'].update(percentual_total_mae)

                relative_mae = mae / (t_left + t_right)
                metrics['loi_mape'].update(relative_mae)

                print("LOI performance (MAE: {}, RMSE: {}, MAPE: {})".format(
                    mae, math.sqrt(mse), relative_mae))

                results.append({
                    'vid': v_i,
                    'loi': l_i,
                    'mae': mae,
                    'mse': mse,
                    'ptmae': percentual_total_mae,
                    'rmae': relative_mae
                })

                # @TODO Do this in general!! Not only for Dam, because these results are interesting to compare
                if args.dataset == 'dam':
                    results = {'per_frame': per_frame}

                    with open('dam_results_{}_{}.json'.format(v_i, l_i),
                              'w') as outfile:
                        json.dump(results, outfile)

        if args.loi_level == 'take_image':
            return

        # @TODO Move this to utils!!
        if args.dataset == 'ucsd':
            ucsd_total_gt = ucsdpeds.load_countings('data/ucsdpeds')
            ucsd_total_count2 = ucsd_total_count
            ucsd_total_count = [[], []]

            for i in range(len(ucsd_total_count2[0])):
                for _ in range(args.frames_between):
                    ucsd_total_count[0].append(ucsd_total_count2[0][i] /
                                               args.frames_between)
                    ucsd_total_count[1].append(ucsd_total_count2[1][i] /
                                               args.frames_between)

            wmae = [[], []]
            tmae = [[], []]
            imae = [[], []]

            for i, _ in enumerate(ucsd_total_count[0]):
                imae[0].append(
                    abs(ucsd_total_count[0][i] - ucsd_total_gt[0][i]))
                imae[1].append(
                    abs(ucsd_total_count[1][i] - ucsd_total_gt[1][i]))

                if i >= 600:
                    tmae[0].append(
                        abs(
                            sum(ucsd_total_count[0][600:i + 1]) -
                            sum(ucsd_total_gt[0][600:i + 1])))
                    tmae[1].append(
                        abs(
                            sum(ucsd_total_count[1][600:i + 1]) -
                            sum(ucsd_total_gt[1][600:i + 1])))

                    if i + 100 < 1200:
                        wmae[0].append(
                            abs(
                                sum(ucsd_total_count[0][i:i + 100]) -
                                sum(ucsd_total_gt[0][i:i + 100])))
                        wmae[1].append(
                            abs(
                                sum(ucsd_total_count[1][i:i + 100]) -
                                sum(ucsd_total_gt[1][i:i + 100])))
                else:
                    tmae[0].append(
                        abs(
                            sum(ucsd_total_count[0][:i + 1]) -
                            sum(ucsd_total_gt[0][:i + 1])))
                    tmae[1].append(
                        abs(
                            sum(ucsd_total_count[1][:i + 1]) -
                            sum(ucsd_total_gt[1][:i + 1])))

                    if i + 100 < 600:
                        wmae[0].append(
                            abs(
                                sum(ucsd_total_count[0][i:i + 100]) -
                                sum(ucsd_total_gt[0][i:i + 100])))
                        wmae[1].append(
                            abs(
                                sum(ucsd_total_count[1][i:i + 100]) -
                                sum(ucsd_total_gt[1][i:i + 100])))

            print("UCSD results, total error left: {}, right: {}".format(
                abs(sum(ucsd_total_count[0]) - sum(ucsd_total_gt[0])),
                abs(sum(ucsd_total_count[1]) - sum(ucsd_total_gt[1]))))
            print("IMAE: {} | {}".format(
                sum(imae[0]) / len(imae[0]),
                sum(imae[1]) / len(imae[1])))
            print("TMAE: {} | {}".format(
                sum(tmae[0]) / len(tmae[0]),
                sum(tmae[1]) / len(tmae[1])))
            print("WMAE: {} | {}".format(
                sum(wmae[0]) / len(wmae[0]),
                sum(wmae[1]) / len(wmae[1])))
        # END of UCSD special testing

        # Save all results. Some metrics are not computed for every configuration; those are then often 0.
        # @TODO add to the README which results are valid in which setting
        results = {
            'loi_mae': metrics['loi_mae'].avg,
            'loi_mse': metrics['loi_mse'].avg,
            'loi_rmae': metrics['loi_rmae'].avg,
            'loi_ptmae': metrics['loi_ptmae'].avg,
            'loi_mape': metrics['loi_mape'].avg,
            'old_roi_mae': metrics['old_roi_mae'].avg,
            'old_roi_mse': metrics['old_roi_mse'].avg,
            'real_roi_mae': metrics['real_roi_mae'].avg,
            'real_roi_mse': metrics['real_roi_mse'].avg,
            'moving_mae': metrics['m_mae'].avg,
            'moving_mse': metrics['m_mse'].avg,
            'timing': metrics['timing'].avg,
            'per_vid': results
        }
        outname = 'new_{}_{}_{}_{}_{}_{}'.format(
            args.dataset, args.model, args.eval_method, args.loi_level,
            args.loi_maxing,
            datetime.now().strftime("%Y%m%d_%H%M%S"))
        with open('loi_results/{}.json'.format(outname), 'w') as outfile:
            json.dump(results, outfile)

        # Print simple results
        print("MAE: {}, MSE: {}, MAPE: {}".format(metrics['loi_mae'].avg,
                                                  metrics['loi_mse'].avg,
                                                  metrics['loi_mape'].avg))

        return results