Example #1
def test(hparams, model, dataset, pickle_path):
    PAD_id = hparams["PAD_id"]
    EOS_id = hparams["EOS_id"]
    vocab = dataset.vocab
    model.eval()

    dial_list = []
    pbar = tqdm.tqdm(enumerate(dataset), total=len(dataset))
    for idx, data in pbar:
        data = collate_fn([data])
        pbar.set_description("Dial {}".format(idx + 1))
        inf_uttrs, decoded_uttrs, likelihoods = inference(hparams, model, vocab, data)

        dial_list += [{
            "id":idx,
            "src":[
                " ".join([vocab["itow"].get(w,"<unk>") for w in s
                          if w != PAD_id and w != EOS_id])
                for s in data["src"][0].numpy()
            ],
            "tgt":" ".join([
                vocab["itow"].get(w,"<unk>")
                for w in data["tgt"][0].numpy()
                if w != PAD_id and w != EOS_id
            ]),
            "tgt_id":[[id for id in data["tgt"][0].tolist() if id != PAD_id and id != EOS_id]],
            "inf_id":[[id for id in res if id != PAD_id and id != EOS_id] for res in inf_uttrs],
            "inf" : decoded_uttrs,
            "likelihood":likelihoods,
        }]
        # Overwrite the pickle after every dialogue
        with open(pickle_path, mode="wb") as f:
            pickle.dump(dial_list, f)
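Because dial_list is re-pickled after every dialogue, partial results survive an interrupted run. A minimal sketch for reading the dump back afterwards (assuming the same pickle_path):

import pickle

with open(pickle_path, mode="rb") as f:
    dial_list = pickle.load(f)  # list of dicts with "src", "tgt", "inf", "likelihood", ...
print(len(dial_list), "dialogues decoded")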
Example #2
def chat(hparams, model, vocab):
    model.eval()

    MAX_INPUT_LEN = hparams["MAX_DIAL_LEN"] - 1
    src_list = []
    while True:
        src_sent = input("> ")
        loginfo_and_print(logger, "User: {}".format(src_sent))
        if src_sent == ":q" or src_sent == ":quit":
            break
        src_list += [src_sent]
        print("src_list",src_list)
        src = [
            [vocab["wtoi"].get(w.lower(),hparams["UNK_id"]) for w in word_tokenize(s)] + [hparams["EOS_id"]]
            for s in src_list[-MAX_INPUT_LEN:]
        ]
        data = collate_fn([{"src":src}])
        _, decoded_uttrs,_ = inference(hparams,model,vocab,data,chat_mode=True)

        src_list += [decoded_uttrs[0]]
        loginfo_and_print(logger, "Bot : {}".format(decoded_uttrs[0]))
Example #3
    def __init__(self, model, vocab):

        assert isinstance(model, dict) or isinstance(model, str)
        assert isinstance(vocab, tuple) or isinstance(vocab, str)

        # dataset
        logger.info('-' * 100)
        logger.info('Loading test dataset')
        self.dataset = data.CodePtrDataset(mode='test')
        self.dataset_size = len(self.dataset)
        logger.info('Size of test dataset: {}'.format(self.dataset_size))

        logger.info('The dataset is successfully loaded')

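        # NOTE: the vocab attributes referenced in the collate_fn lambda below are
        # assigned later in __init__; the lambda only runs once the loader is iterated.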
        self.dataloader = DataLoader(dataset=self.dataset,
                                     batch_size=config.test_batch_size,
                                     collate_fn=lambda *args: utils.collate_fn(args,
                                                                               source_vocab=self.source_vocab,
                                                                               code_vocab=self.code_vocab,
                                                                               ast_vocab=self.ast_vocab,
                                                                               nl_vocab=self.nl_vocab,
                                                                               raw_nl=True))

        # vocab
        logger.info('-' * 100)
        if isinstance(vocab, tuple):
            logger.info('Vocabularies are passed from parameters')
            assert len(vocab) == 4
            self.source_vocab, self.code_vocab, self.ast_vocab, self.nl_vocab = vocab
        else:
            logger.info('Vocabularies are read from dir: {}'.format(vocab))
            self.source_vocab = utils.load_vocab(vocab, 'source')
            self.code_vocab = utils.load_vocab(vocab, 'code')
            self.ast_vocab = utils.load_vocab(vocab, 'ast')
            self.nl_vocab = utils.load_vocab(vocab, 'nl')

        # vocabulary sizes
        self.source_vocab_size = len(self.source_vocab)
        self.code_vocab_size = len(self.code_vocab)
        self.ast_vocab_size = len(self.ast_vocab)
        self.nl_vocab_size = len(self.nl_vocab)

        logger.info('Size of source vocabulary: {} -> {}'.format(self.source_vocab.origin_size, self.source_vocab_size))
        logger.info('Size of code vocabulary: {} -> {}'.format(self.code_vocab.origin_size, self.code_vocab_size))
        logger.info('Size of ast vocabulary: {}'.format(self.ast_vocab_size))
        logger.info('Size of nl vocabulary: {} -> {}'.format(self.nl_vocab.origin_size, self.nl_vocab_size))

        logger.info('Vocabularies are successfully built')

        # model
        logger.info('-' * 100)
        logger.info('Building model')
        self.model = models.Model(source_vocab_size=self.source_vocab_size,
                                  code_vocab_size=self.code_vocab_size,
                                  ast_vocab_size=self.ast_vocab_size,
                                  nl_vocab_size=self.nl_vocab_size,
                                  is_eval=True,
                                  model=model)
        # model device
        logger.info('Model device: {}'.format(next(self.model.parameters()).device))
        # log model statistic
        logger.info('Trainable parameters: {}'.format(utils.human_format(utils.count_params(self.model))))
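The lambda above just closes over the vocabularies so that the single batch argument supplied by DataLoader can be forwarded to utils.collate_fn together with them. An alternative is functools.partial; the sketch below assumes the target collate function takes the batch as its first positional argument, which differs from the *args tuple packed by the lambda. Unlike a lambda, a partial object is picklable, which matters if num_workers > 0 is ever used with the spawn start method:

from functools import partial

collate = partial(utils.collate_fn,
                  source_vocab=self.source_vocab,
                  code_vocab=self.code_vocab,
                  ast_vocab=self.ast_vocab,
                  nl_vocab=self.nl_vocab,
                  raw_nl=True)
self.dataloader = DataLoader(dataset=self.dataset,
                             batch_size=config.test_batch_size,
                             collate_fn=collate)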
Example #4
def main():
    print(f"Beginning training at {time.time()}...")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        help="Name of the model",
        choices=MODELS,
        default=TRANSFORMER,
    )
    args = parser.parse_args()
    model_type = args.model

    if utils.is_spot_instance():
        signal.signal(signal.SIGTERM, utils.sigterm_handler)

    # For laptop & deep learning rig testing on the same codebase
    if not torch.cuda.is_available():
        multiprocessing.set_start_method("spawn", True)
        device = torch.device("cpu")
        num_workers = 0
        max_elements = 5
        save_checkpoints = False
    else:
        device = torch.device("cuda")
        # https://github.com/facebookresearch/maskrcnn-benchmark/issues/195
        num_workers = 0
        max_elements = None
        save_checkpoints = True

    deterministic = True
    if deterministic:
        seed = 0
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        os.environ["PYTHONHASHSEED"] = str(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    shuffle = not deterministic

    # Hyperparameters
    batch_size = 1024 if torch.cuda.device_count() > 1 else 8
    lr = 6e-4
    warmup_lr = 6e-6  # TODO: Refactor into custom optimizer class
    warmup_interval = None  # 10000  # or None
    beta_coeff_low = 0.9
    beta_coeff_high = 0.995
    eps = 1e-9
    smoothing = False
    weight_sharing = True

    # Config
    unique_id = f"6-24-20_{model_type}1"
    exp = "math_112m_bs128"
    name = f"{exp}_{unique_id}"
    run_max_batches = 500000  # Defined in paper
    should_restore_checkpoint = True
    pin_memory = True

    print("Model name:", name)
    print(
        f"Batch size: {batch_size}. Learning rate: {lr}. Warmup_lr: {warmup_lr}. Warmup interval: {warmup_interval}. B low {beta_coeff_low}. B high {beta_coeff_high}. eps {eps}. Smooth: {smoothing}"
    )
    print("Deterministic:", deterministic)
    print("Device:", device)
    print("Should restore checkpoint:", should_restore_checkpoint)

    model = utils.build_model(model_type, weight_sharing)

    optimizer = optim.Adam(
        model.parameters(),
        lr=lr if warmup_interval is None else warmup_lr,
        betas=(beta_coeff_low, beta_coeff_high),
        eps=eps,
    )

    tb = Tensorboard(exp, unique_name=unique_id)

    # Run state
    start_batch = 0
    start_epoch = 0
    run_batches = 0
    total_loss = 0
    n_char_total = 0
    n_char_correct = 0

    if should_restore_checkpoint:
        cp_path = f"checkpoints/{name}_latest_checkpoint.pth"
        # cp_path = "checkpoint_b109375_e0.pth"

        state = restore_checkpoint(
            cp_path,
            model_type=model_type,
            model=model,
            optimizer=optimizer,
        )

        if state is not None:
            start_epoch = state["epoch"]
            # best_acc = state["acc"]
            # best_loss = state["loss"]
            run_batches = state["run_batches"]
            lr = state["lr"]
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            start_batch = state.get("start_batch", None) or 0
            total_loss = state.get("total_loss", None) or 0
            n_char_total = state.get("n_char_total", None) or 0
            n_char_correct = state.get("n_char_correct", None) or 0

            # Need to move optimizer state to GPU memory
            if torch.cuda.is_available():
                for opt_state in optimizer.state.values():
                    for k, v in opt_state.items():
                        if torch.is_tensor(v):
                            opt_state[k] = v.cuda()

            print(f"Setting lr to {lr}")
            print("Loaded checkpoint successfully")

    print("start_epoch", start_epoch)
    print("start_batch", start_batch)
    print("total_loss", total_loss)

    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    model = model.to(device)

    dataset_path = "./mathematics_dataset-v1.0"
    mini_dataset_path = "./mini_dataset"
    if not os.path.isdir(Path(dataset_path)):
        print(
            "Full dataset not detected. Using backup mini dataset for testing. See repo for instructions on downloading full dataset."
        )
        dataset_path = mini_dataset_path

    ds_train = FullDatasetManager(
        dataset_path,
        max_elements=max_elements,
        deterministic=deterministic,
        start_epoch=start_epoch,
        start_datapoint=start_batch * batch_size,
    )
    print("Train size:", len(ds_train))

    ds_interpolate = FullDatasetManager(
        dataset_path,
        max_elements=max_elements,
        deterministic=deterministic,
        start_epoch=start_epoch,
        mode="interpolate",
    )
    print("Interpolate size:", len(ds_interpolate))

    ds_extrapolate = FullDatasetManager(
        dataset_path,
        max_elements=max_elements,
        deterministic=deterministic,
        start_epoch=start_epoch,
        mode="extrapolate",
    )
    print("Extrapolate size:", len(ds_extrapolate))

    collate_fn = utils.collate_fn(model_type)

    train_loader = data.DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
    )

    interpolate_loader = data.DataLoader(
        ds_interpolate,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
    )

    extrapolate_loader = data.DataLoader(
        ds_extrapolate,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
    )

    model_process.train(
        name=name,
        model=model,
        training_data=train_loader,
        optimizer=optimizer,
        device=device,
        epochs=1000,  # Not relevant; training ends earlier once run_max_batches is reached
        tb=tb,
        run_max_batches=run_max_batches,
        validation_data=None,
        start_epoch=start_epoch,
        start_batch=start_batch,
        total_loss=total_loss,
        n_char_total=n_char_total,
        n_char_correct=n_char_correct,
        run_batches=run_batches,
        interpolate_data=interpolate_loader,
        extrapolate_data=extrapolate_loader,
        checkpoint=save_checkpoints,
        lr=lr,
        warmup_lr=warmup_lr,
        warmup_interval=warmup_interval,
        smoothing=smoothing,
    )
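Note that utils.collate_fn(model_type) is called once up front and its return value is handed to all three loaders, i.e. it acts as a factory that builds a model-specific collate function. A hypothetical sketch of that shape (not the repository's actual code):

def make_collate_fn(model_type):
    # Hypothetical factory: returns a collate function whose batching behaviour
    # can depend on the chosen model type.
    def collate(batch):
        inputs, targets = zip(*batch)  # assumes each sample is an (input, target) pair
        return list(inputs), list(targets)
    return collate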
Example #5
    def test_simple(self):
        batch: List[Tuple[str, int]] = [('a', 1)]
        ret = utils.collate_fn(batch)
        self.assertEqual(ret, (('a',), (1,)))
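The assertion pins down the behaviour under test: collate_fn transposes a batch of tuples into a tuple of per-field tuples. A minimal implementation consistent with this test (a sketch, not necessarily the project's actual utils.collate_fn):

from typing import Any, Sequence, Tuple

def collate_fn(batch: Sequence[Tuple[Any, ...]]) -> Tuple[Tuple[Any, ...], ...]:
    # [('a', 1), ('b', 2)] -> (('a', 'b'), (1, 2))
    return tuple(zip(*batch))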
Example #6
def copypaste_collate_fn(batch):
    copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
    return copypaste(*utils.collate_fn(batch))
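Here utils.collate_fn first groups the batch into (images, targets) and SimpleCopyPaste then pastes objects between the images at collate time. A hedged usage sketch, assuming a detection dataset, batch sampler and args object already exist as in the torchvision reference scripts:

import torch

data_loader = torch.utils.data.DataLoader(
    dataset,                            # assumed: yields (image, target) pairs
    batch_sampler=train_batch_sampler,  # assumed to be defined elsewhere
    num_workers=args.workers,
    collate_fn=copypaste_collate_fn,    # copy-paste augmentation happens during collation
)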
Example #7
    def __getitem__(self, i):
        data0 = super().__getitem__(i, pad=False, add_special_toks=False)
        data0['seq_type'] = 0

        data = [data0]
        vid_id = data0['vid_ids']

        eval_mask = self.test_masking_policy is not None and self.test_masking_policy != 'random'

        if eval_mask:
            vid_id, act_id, frame_id = self.frames[data0['indices']]
            item_p_mask, pos_list = compute_item_mask(
                self.actions[vid_id][act_id], data0['text'], self.split_data,
                self.test_masking_policy, self.tokenizer)

            tok_groups = []
            tok_group_labels = []
            tok_starts_ends = []
            try:
                for pos, l, grp in pos_list:
                    tok_groups.append(l)
                    tok_group_labels.append(grp)
                    tok_starts_ends.append((pos, pos + l))
                tgt_token_ids = list(
                    set(data0['text'][item_p_mask.bool()].tolist()))
                # farm as many negatives as positives
                n_neg = len(tgt_token_ids) * self.negs_per_pos
            except Exception:  # fall back to an empty target set if the mask positions cannot be parsed
                n_neg = 0
                tgt_token_ids = []

        else:
            n_pos = random.randint(self.min_positives, self.max_positives)
            n_neg = random.randint(self.min_negatives, self.max_negatives)

            special_ids = self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.all_special_tokens)
            tgt_token_ids = list(
                set(t for t in data0['text'].tolist() if t not in special_ids))
            # random.sample needs a sequence, and the index assignments below need a list
            if len(tgt_token_ids) > n_pos:
                tgt_token_ids = random.sample(tgt_token_ids, n_pos)

        pos_txt_ids = set()
        cnt_positives = 0
        for idx, token_id in enumerate(tgt_token_ids):
            if token_id in pos_txt_ids:
                continue
            if cnt_positives >= self.max_positives:
                # tgt_token_ids[idx] = 0
                break
            # select positive
            candidate_positives = self.frame_words[token_id]
            if not len(candidate_positives):
                tgt_token_ids[idx] = 0
                continue
            # remove all frames belonging to the same video
            candidate_positives_same = [
                i for i, _, _ in filter(lambda x: x[1] == vid_id,
                                        candidate_positives)
            ]
            candidate_positives_different = [
                i for i, _, _ in filter(lambda x: x[1] != vid_id,
                                        candidate_positives)
            ]
            if len(candidate_positives_different) > 0:
                j = random.choice(candidate_positives_different)
            else:
                # only if that word does not exist in any other example, use element from the same action as positive
                j = random.choice(candidate_positives_same)
            data_pos = super().__getitem__(j,
                                           pad=False,
                                           add_special_toks=False)
            if data_pos['imgs'] is None:
                tgt_token_ids[idx] = 0
                continue
            data_pos['seq_type'] = 1
            data.append(data_pos)
            pos_txt_ids.update(data_pos['text'].tolist())
            cnt_positives += 1

        n_neg = min(n_neg, self.max_negatives)
        n_neg = max(n_neg, self.min_negatives)
        while n_neg:
            j = random.randint(0, self.__len__() - 1)
            data_neg = super().__getitem__(j,
                                           pad=False,
                                           add_special_toks=False)
            if data_neg['imgs'] is None or set(
                    data_neg['text'].tolist()) & set(tgt_token_ids):
                continue
            data_neg['seq_type'] = -1
            data.append(data_neg)
            n_neg -= 1

        random.shuffle(data)

        collated_data = collate_fn(data, cat_tensors=True)
        collated_data['target_token_ids'] = torch.LongTensor(
            list(tgt_token_ids))
        collated_data['num_seqs'] = len(data)
        collated_data['num_tgt_toks'] = len(tgt_token_ids)
        if eval_mask:
            collated_data['tok_groups'] = tok_groups
            collated_data['tok_group_labels'] = tok_group_labels
        else:
            collated_data['tok_groups'] = []
            collated_data['tok_group_labels'] = []
        return collated_data
Example #8
vocab = Vocab(args.vocab_file_path)
assert len(vocab) > 0

# Setup GPU
use_cuda = bool(args.cuda and torch.cuda.is_available())
assert use_cuda  # Trust me, you don't want to train this model on a cpu.
device = torch.device("cuda" if use_cuda else "cpu")

transform_imgs = resnet_img_transformation(args.img_crop_size)

# Creating the data loader
train_loader = DataLoader(
    Pix2CodeDataset(args.data_path, args.split,
                    vocab, transform=transform_imgs),
    batch_size=args.batch_size,
    collate_fn=lambda data: collate_fn(data, vocab=vocab),
    pin_memory=use_cuda,
    num_workers=4,
    drop_last=True)
print("Created data loader")

# Creating the models
embed_size = 256
hidden_size = 512
num_layers = 1
lr = args.lr

encoder = Encoder(embed_size)
decoder = Decoder(embed_size, hidden_size, len(vocab), num_layers)

encoder = encoder.to(device)
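The lambda passed as collate_fn above binds vocab so the batch can be collated with access to the vocabulary. One plausible shape of such a collate function for (image, caption) pairs is sketched below; it is an assumption rather than the repository's actual implementation, and vocab.get_pad_id() is a placeholder for however the Vocab class exposes its padding index:

import torch

def collate_fn(data, vocab):
    # Sort by caption length so the longest caption defines the padded width.
    data.sort(key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*data)
    images = torch.stack(images, 0)
    lengths = [len(cap) for cap in captions]
    padded = torch.full((len(captions), max(lengths)), vocab.get_pad_id(), dtype=torch.long)
    for i, cap in enumerate(captions):
        padded[i, :lengths[i]] = cap
    return images, padded, lengths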
Example #9
File: train.py Project: NougatCA/CodePtr
    def __init__(self):

        # dataset
        logger.info('-' * 100)
        logger.info('Loading training and validation datasets')
        self.dataset = data.CodePtrDataset(mode='train')
        self.dataset_size = len(self.dataset)
        logger.info('Size of training dataset: {}'.format(self.dataset_size))
        self.dataloader = DataLoader(dataset=self.dataset,
                                     batch_size=config.batch_size,
                                     shuffle=True,
                                     collate_fn=lambda *args: utils.collate_fn(
                                         args,
                                         source_vocab=self.source_vocab,
                                         code_vocab=self.code_vocab,
                                         ast_vocab=self.ast_vocab,
                                         nl_vocab=self.nl_vocab))

        # valid dataset
        self.valid_dataset = data.CodePtrDataset(mode='valid')
        self.valid_dataset_size = len(self.valid_dataset)
        self.valid_dataloader = DataLoader(
            dataset=self.valid_dataset,
            batch_size=config.valid_batch_size,
            collate_fn=lambda *args: utils.collate_fn(
                args,
                source_vocab=self.source_vocab,
                code_vocab=self.code_vocab,
                ast_vocab=self.ast_vocab,
                nl_vocab=self.nl_vocab))
        logger.info('Size of validation dataset: {}'.format(
            self.valid_dataset_size))
        logger.info('The datasets are successfully loaded')

        # vocab
        logger.info('-' * 100)
        logger.info('Building vocabularies')

        sources, codes, asts, nls = self.dataset.get_dataset()

        self.source_vocab = utils.build_word_vocab(
            dataset=sources,
            vocab_name='source',
            ignore_case=True,
            max_vocab_size=config.source_vocab_size,
            save_dir=config.vocab_root)
        self.source_vocab_size = len(self.source_vocab)
        logger.info('Size of source vocab: {} -> {}'.format(
            self.source_vocab.origin_size, self.source_vocab_size))

        self.code_vocab = utils.build_word_vocab(
            dataset=codes,
            vocab_name='code',
            ignore_case=True,
            max_vocab_size=config.code_vocab_size,
            save_dir=config.vocab_root)
        self.code_vocab_size = len(self.code_vocab)
        logger.info('Size of code vocab: {} -> {}'.format(
            self.code_vocab.origin_size, self.code_vocab_size))

        self.ast_vocab = utils.build_word_vocab(dataset=asts,
                                                vocab_name='ast',
                                                ignore_case=True,
                                                save_dir=config.vocab_root)
        self.ast_vocab_size = len(self.ast_vocab)
        logger.info('Size of ast vocab: {}'.format(self.ast_vocab_size))

        self.nl_vocab = utils.build_word_vocab(
            dataset=nls,
            vocab_name='nl',
            ignore_case=True,
            max_vocab_size=config.nl_vocab_size,
            save_dir=config.vocab_root)
        self.nl_vocab_size = len(self.nl_vocab)
        logger.info('Size of nl vocab: {} -> {}'.format(
            self.nl_vocab.origin_size, self.nl_vocab_size))

        logger.info('Vocabularies are successfully built')

        # model
        logger.info('-' * 100)
        logger.info('Building the model')
        self.model = models.Model(source_vocab_size=self.source_vocab_size,
                                  code_vocab_size=self.code_vocab_size,
                                  ast_vocab_size=self.ast_vocab_size,
                                  nl_vocab_size=self.nl_vocab_size)
        # model device
        logger.info('Model device: {}'.format(
            next(self.model.parameters()).device))
        # log model statistic
        logger.info('Trainable parameters: {}'.format(
            utils.human_format(utils.count_params(self.model))))

        # optimizer
        self.optimizer = Adam([
            {
                'params': self.model.parameters(),
                'lr': config.learning_rate
            },
        ])

        self.criterion = nn.CrossEntropyLoss(
            ignore_index=self.nl_vocab.get_pad_index())

        if config.use_lr_decay:
            self.lr_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                    step_size=1,
                                                    gamma=config.lr_decay_rate)

        # early stopping
        self.early_stopping = None
        if config.use_early_stopping:
            self.early_stopping = utils.EarlyStopping(
                patience=config.early_stopping_patience, high_record=False)
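With step_size=1, the StepLR scheduler multiplies the learning rate by lr_decay_rate each time step() is called, i.e. once per epoch in a typical loop. A sketch of how it would be driven (train_one_epoch is a hypothetical helper and config.n_epochs an assumed setting):

for epoch in range(config.n_epochs):
    train_one_epoch(self.model, self.dataloader, self.optimizer, self.criterion)
    if config.use_lr_decay:
        self.lr_scheduler.step()  # lr <- lr * config.lr_decay_rate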