def test(hparams, model, dataset, pickle_path):
    PAD_id = hparams["PAD_id"]
    EOS_id = hparams["EOS_id"]
    vocab = dataset.vocab
    model.eval()

    dial_list = []
    pbar = tqdm.tqdm(enumerate(dataset), total=len(dataset))
    for idx, data in pbar:
        data = collate_fn([data])
        pbar.set_description("Dial {}".format(idx + 1))
        inf_uttrs, decoded_uttrs, likelihoods = inference(hparams, model, vocab, data)
        dial_list += [{
            "id": idx,
            "src": [
                " ".join([vocab["itow"].get(w, "<unk>") for w in s if w != PAD_id and w != EOS_id])
                for s in data["src"][0].numpy()
            ],
            "tgt": " ".join([
                vocab["itow"].get(w, "<unk>")
                for w in data["tgt"][0].numpy()
                if w != PAD_id and w != EOS_id
            ]),
            "tgt_id": [[id for id in data["tgt"][0].tolist() if id != PAD_id and id != EOS_id]],
            "inf_id": [[id for id in res if id != PAD_id and id != EOS_id] for res in inf_uttrs],
            "inf": decoded_uttrs,
            "likelihood": likelihoods,
        }]

    with open(pickle_path, mode="wb") as f:
        pickle.dump(dial_list, f)
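# The test() and chat() snippets call a project-specific collate_fn that is not
# shown here. As a rough, hypothetical sketch only (the name collate_fn_sketch,
# the pad id, and the tensor layout are assumptions, not the repo's actual
# implementation), a batcher for this dialogue format could pad the token-id
# lists and wrap them in LongTensors of shape (batch, n_utterances, max_len):
import torch

def collate_fn_sketch(batch, pad_id=0):
    # batch: list of dicts with "src" (list of token-id lists) and optional "tgt" (token-id list)
    n_utt_max = max(len(item["src"]) for item in batch)
    max_src = max(len(s) for item in batch for s in item["src"])
    src = torch.full((len(batch), n_utt_max, max_src), pad_id, dtype=torch.long)
    for b, item in enumerate(batch):
        for i, s in enumerate(item["src"]):
            src[b, i, :len(s)] = torch.tensor(s, dtype=torch.long)
    out = {"src": src}
    if "tgt" in batch[0]:
        max_tgt = max(len(item["tgt"]) for item in batch)
        tgt = torch.full((len(batch), max_tgt), pad_id, dtype=torch.long)
        for b, item in enumerate(batch):
            tgt[b, :len(item["tgt"])] = torch.tensor(item["tgt"], dtype=torch.long)
        out["tgt"] = tgt
    return out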
def chat(hparams, model, vocab):
    model.eval()
    MAX_INPUT_LEN = hparams["MAX_DIAL_LEN"] - 1

    src_list = []
    while True:
        src_sent = input("> ")
        loginfo_and_print(logger, "User: {}".format(src_sent))
        if src_sent == ":q" or src_sent == ":quit":
            break
        src_list += [src_sent]
        print("src_list", src_list)
        src = [
            [vocab["wtoi"].get(w.lower(), hparams["UNK_id"]) for w in word_tokenize(s)] + [hparams["EOS_id"]]
            for s in src_list[-MAX_INPUT_LEN:]
        ]
        data = collate_fn([{"src": src}])
        _, decoded_uttrs, _ = inference(hparams, model, vocab, data, chat_mode=True)
        src_list += [decoded_uttrs[0]]
        loginfo_and_print(logger, "Bot : {}".format(decoded_uttrs[0]))
def __init__(self, model, vocab):
    assert isinstance(model, dict) or isinstance(model, str)
    assert isinstance(vocab, tuple) or isinstance(vocab, str)

    # dataset
    logger.info('-' * 100)
    logger.info('Loading test dataset')
    self.dataset = data.CodePtrDataset(mode='test')
    self.dataset_size = len(self.dataset)
    logger.info('Size of test dataset: {}'.format(self.dataset_size))
    logger.info('The dataset is successfully loaded')

    # vocab
    logger.info('-' * 100)
    if isinstance(vocab, tuple):
        logger.info('Vocabularies are passed from parameters')
        assert len(vocab) == 4
        self.source_vocab, self.code_vocab, self.ast_vocab, self.nl_vocab = vocab
    else:
        logger.info('Vocabularies are read from dir: {}'.format(vocab))
        self.source_vocab = utils.load_vocab(vocab, 'source')
        self.code_vocab = utils.load_vocab(vocab, 'code')
        self.ast_vocab = utils.load_vocab(vocab, 'ast')
        self.nl_vocab = utils.load_vocab(vocab, 'nl')

    # vocabulary sizes
    self.source_vocab_size = len(self.source_vocab)
    self.code_vocab_size = len(self.code_vocab)
    self.ast_vocab_size = len(self.ast_vocab)
    self.nl_vocab_size = len(self.nl_vocab)

    logger.info('Size of source vocabulary: {} -> {}'.format(self.source_vocab.origin_size, self.source_vocab_size))
    logger.info('Size of code vocabulary: {} -> {}'.format(self.code_vocab.origin_size, self.code_vocab_size))
    logger.info('Size of ast vocabulary: {}'.format(self.ast_vocab_size))
    logger.info('Size of nl vocabulary: {} -> {}'.format(self.nl_vocab.origin_size, self.nl_vocab_size))
    logger.info('Vocabularies are successfully built')

    # dataloader (the collate_fn closure uses the vocabularies built above)
    self.dataloader = DataLoader(dataset=self.dataset,
                                 batch_size=config.test_batch_size,
                                 collate_fn=lambda *args: utils.collate_fn(args,
                                                                           source_vocab=self.source_vocab,
                                                                           code_vocab=self.code_vocab,
                                                                           ast_vocab=self.ast_vocab,
                                                                           nl_vocab=self.nl_vocab,
                                                                           raw_nl=True))

    # model
    logger.info('-' * 100)
    logger.info('Building model')
    self.model = models.Model(source_vocab_size=self.source_vocab_size,
                              code_vocab_size=self.code_vocab_size,
                              ast_vocab_size=self.ast_vocab_size,
                              nl_vocab_size=self.nl_vocab_size,
                              is_eval=True,
                              model=model)
    # model device
    logger.info('Model device: {}'.format(next(self.model.parameters()).device))
    # log model statistics
    logger.info('Trainable parameters: {}'.format(utils.human_format(utils.count_params(self.model))))
def main(): print(f"Beginning training at {time.time()}...") parser = argparse.ArgumentParser() parser.add_argument( "-m", "--model", help="Name of the model", choices=MODELS, default=TRANSFORMER, ) args = parser.parse_args() model_type = args.model if utils.is_spot_instance(): signal.signal(signal.SIGTERM, utils.sigterm_handler) # For laptop & deep learning rig testing on the same codebase if not torch.cuda.is_available(): multiprocessing.set_start_method("spawn", True) device = torch.device("cpu") num_workers = 0 max_elements = 5 save_checkpoints = False else: device = torch.device("cuda") # https://github.com/facebookresearch/maskrcnn-benchmark/issues/195 num_workers = 0 max_elements = None save_checkpoints = True deterministic = True if deterministic: seed = 0 torch.manual_seed(seed) torch.cuda.manual_seed(seed) random.seed(seed) np.random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False shuffle = not deterministic # Hyperparameters batch_size = 1024 if torch.cuda.device_count() > 1 else 8 lr = 6e-4 warmup_lr = 6e-6 # TODO: Refactor into custom optimizer class warmup_interval = None # 10000 # or None beta_coeff_low = 0.9 beta_coeff_high = 0.995 eps = 1e-9 smoothing = False weight_sharing = True # Config unique_id = f"6-24-20_{model_type}1" exp = "math_112m_bs128" name = f"{exp}_{unique_id}" run_max_batches = 500000 # Defined in paper should_restore_checkpoint = True pin_memory = True print("Model name:", name) print( f"Batch size: {batch_size}. Learning rate: {lr}. Warmup_lr: {warmup_lr}. Warmup interval: {warmup_interval}. B low {beta_coeff_low}. B high {beta_coeff_high}. eps {eps}. Smooth: {smoothing}" ) print("Deterministic:", deterministic) print("Device:", device) print("Should restore checkpoint:", should_restore_checkpoint) model = utils.build_model(model_type, weight_sharing) optimizer = optim.Adam( model.parameters(), lr=lr if warmup_interval is None else warmup_lr, betas=(beta_coeff_low, beta_coeff_high), eps=eps, ) tb = Tensorboard(exp, unique_name=unique_id) # Run state start_batch = 0 start_epoch = 0 run_batches = 0 total_loss = 0 n_char_total = 0 n_char_correct = 0 if should_restore_checkpoint: cp_path = f"checkpoints/{name}_latest_checkpoint.pth" # cp_path = "checkpoint_b109375_e0.pth" state = restore_checkpoint( cp_path, model_type=model_type, model=model, optimizer=optimizer, ) if state is not None: start_epoch = state["epoch"] # best_acc = state["acc"] # best_loss = state["loss"] run_batches = state["run_batches"] lr = state["lr"] for param_group in optimizer.param_groups: param_group["lr"] = lr start_batch = state.get("start_batch", None) or 0 total_loss = state.get("total_loss", None) or 0 n_char_total = state.get("n_char_total", None) or 0 n_char_correct = state.get("n_char_correct", None) or 0 # Need to move optimizer state to GPU memory if torch.cuda.is_available(): for state in optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.cuda() print(f"Setting lr to {lr}") print("Loaded checkpoint successfully") print("start_epoch", start_epoch) print("start_batch", start_batch) print("total_loss", total_loss) if torch.cuda.device_count() > 1: print("Using", torch.cuda.device_count(), "GPUs!") model = torch.nn.DataParallel(model) model = model.to(device) dataset_path = "./mathematics_dataset-v1.0" mini_dataset_path = "./mini_dataset" if not os.path.isdir(Path(dataset_path)): print( "Full dataset not detected. Using backup mini dataset for testing. 
See repo for instructions on downloading full dataset." ) dataset_path = mini_dataset_path ds_train = FullDatasetManager( dataset_path, max_elements=max_elements, deterministic=deterministic, start_epoch=start_epoch, start_datapoint=start_batch * batch_size, ) print("Train size:", len(ds_train)) ds_interpolate = FullDatasetManager( dataset_path, max_elements=max_elements, deterministic=deterministic, start_epoch=start_epoch, mode="interpolate", ) print("Interpolate size:", len(ds_interpolate)) ds_extrapolate = FullDatasetManager( dataset_path, max_elements=max_elements, deterministic=deterministic, start_epoch=start_epoch, mode="extrapolate", ) print("Extrapolate size:", len(ds_extrapolate)) collate_fn = utils.collate_fn(model_type) train_loader = data.DataLoader( ds_train, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, ) interpolate_loader = data.DataLoader( ds_interpolate, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, ) extrapolate_loader = data.DataLoader( ds_extrapolate, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, ) model_process.train( name=name, model=model, training_data=train_loader, optimizer=optimizer, device=device, epochs=1000, # Not relevant, will get ended before this due to max_b tb=tb, run_max_batches=run_max_batches, validation_data=None, start_epoch=start_epoch, start_batch=start_batch, total_loss=total_loss, n_char_total=n_char_total, n_char_correct=n_char_correct, run_batches=run_batches, interpolate_data=interpolate_loader, extrapolate_data=extrapolate_loader, checkpoint=save_checkpoints, lr=lr, warmup_lr=warmup_lr, warmup_interval=warmup_interval, smoothing=smoothing, )
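# restore_checkpoint() above is project-specific and not shown. As a hedged
# sketch only (the name save_checkpoint_sketch, the file layout, and the
# "model_state_dict"/"optimizer_state_dict" key names are assumptions), a
# compatible writer would persist the same keys that main() reads back:
# "epoch", "run_batches", "lr", plus the optional counters.
import torch

def save_checkpoint_sketch(path, model, optimizer, epoch, run_batches, lr,
                           start_batch=0, total_loss=0, n_char_total=0, n_char_correct=0):
    state = {
        "epoch": epoch,
        "run_batches": run_batches,
        "lr": lr,
        "start_batch": start_batch,
        "total_loss": total_loss,
        "n_char_total": n_char_total,
        "n_char_correct": n_char_correct,
        "model_state_dict": model.state_dict(),          # assumed key name
        "optimizer_state_dict": optimizer.state_dict(),  # assumed key name
    }
    torch.save(state, path)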
def test_simple(self):
    batch: List[Tuple[str, int]] = [('a', 1)]
    ret = utils.collate_fn(batch)
    self.assertEqual(ret, (('a',), (1,)))
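# For reference, the behaviour asserted by test_simple is the usual
# "transpose the batch" collate. A minimal sketch (not necessarily the
# project's actual utils.collate_fn) that satisfies the assertion above:
def collate_fn(batch):
    # [('a', 1), ('b', 2)] -> (('a', 'b'), (1, 2))
    return tuple(zip(*batch))

assert collate_fn([('a', 1)]) == (('a',), (1,))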
def copypaste_collate_fn(batch):
    copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)
    return copypaste(*utils.collate_fn(batch))
def __getitem__(self, i):
    data0 = super().__getitem__(i, pad=False, add_special_toks=False)
    data0['seq_type'] = 0
    data = [data0]
    vid_id = data0['vid_ids']

    eval_mask = self.test_masking_policy is not None and self.test_masking_policy != 'random'
    if eval_mask:
        vid_id, act_id, frame_id = self.frames[data0['indices']]
        item_p_mask, pos_list = compute_item_mask(
            self.actions[vid_id][act_id], data0['text'], self.split_data,
            self.test_masking_policy, self.tokenizer)
        tok_groups = []
        tok_group_labels = []
        tok_starts_ends = []
        try:
            for pos, l, grp in pos_list:
                tok_groups.append(l)
                tok_group_labels.append(grp)
                tok_starts_ends.append((pos, pos + l))
            tgt_token_ids = list(set(data0['text'][item_p_mask.bool()].tolist()))
            n_neg = len(tgt_token_ids) * self.negs_per_pos  # farm as many negatives as positives
        except Exception:
            # fall back to no masked targets if pos_list / item_p_mask is unusable
            n_neg = 0
            tgt_token_ids = []
    else:
        n_pos = random.randint(self.min_positives, self.max_positives)
        n_neg = random.randint(self.min_negatives, self.max_negatives)
        tgt_token_ids = set(
            filter(
                lambda t: t not in self.tokenizer.convert_tokens_to_ids(self.tokenizer.all_special_tokens),
                data0['text'].tolist()))
        # convert to a list so entries can be sampled and zeroed out below
        tgt_token_ids = list(tgt_token_ids)
        if len(tgt_token_ids) > n_pos:
            tgt_token_ids = random.sample(tgt_token_ids, n_pos)

    pos_txt_ids = set()
    cnt_positives = 0
    for idx, token_id in enumerate(tgt_token_ids):
        if token_id in pos_txt_ids:
            continue
        if cnt_positives >= self.max_positives:
            # tgt_token_ids[idx] = 0
            break
        # select positive
        candidate_positives = self.frame_words[token_id]
        if not len(candidate_positives):
            tgt_token_ids[idx] = 0
            continue
        # remove all frames belonging to the same video
        candidate_positives_same = [
            i for i, _, _ in filter(lambda x: x[1] == vid_id, candidate_positives)
        ]
        candidate_positives_different = [
            i for i, _, _ in filter(lambda x: x[1] != vid_id, candidate_positives)
        ]
        if len(candidate_positives_different) > 0:
            j = random.choice(candidate_positives_different)
        else:
            # only if that word does not exist in any other example, use element from the same action as positive
            j = random.choice(candidate_positives_same)
        data_pos = super().__getitem__(j, pad=False, add_special_toks=False)
        if data_pos['imgs'] is None:
            tgt_token_ids[idx] = 0
            continue
        data_pos['seq_type'] = 1
        data.append(data_pos)
        pos_txt_ids.update(data_pos['text'].tolist())
        cnt_positives += 1

    n_neg = min(n_neg, self.max_negatives)
    n_neg = max(n_neg, self.min_negatives)
    while n_neg:
        j = random.randint(0, self.__len__() - 1)
        data_neg = super().__getitem__(j, pad=False, add_special_toks=False)
        if data_neg['imgs'] is None or set(data_neg['text'].tolist()) & set(tgt_token_ids):
            continue
        data_neg['seq_type'] = -1
        data.append(data_neg)
        n_neg -= 1

    random.shuffle(data)
    collated_data = collate_fn(data, cat_tensors=True)
    collated_data['target_token_ids'] = torch.LongTensor(list(tgt_token_ids))
    collated_data['num_seqs'] = len(data)
    collated_data['num_tgt_toks'] = len(tgt_token_ids)
    if eval_mask:
        collated_data['tok_groups'] = tok_groups
        collated_data['tok_group_labels'] = tok_group_labels
    else:
        collated_data['tok_groups'] = []
        collated_data['tok_group_labels'] = []
    return collated_data
vocab = Vocab(args.vocab_file_path)
assert len(vocab) > 0

# Setup GPU
use_cuda = True if args.cuda and torch.cuda.is_available() else False
assert use_cuda  # Trust me, you don't want to train this model on a cpu.
device = torch.device("cuda" if use_cuda else "cpu")

transform_imgs = resnet_img_transformation(args.img_crop_size)

# Creating the data loader
train_loader = DataLoader(
    Pix2CodeDataset(args.data_path, args.split, vocab, transform=transform_imgs),
    batch_size=args.batch_size,
    collate_fn=lambda data: collate_fn(data, vocab=vocab),
    pin_memory=True if use_cuda else False,
    num_workers=4,
    drop_last=True)
print("Created data loader")

# Creating the models
embed_size = 256
hidden_size = 512
num_layers = 1
lr = args.lr

encoder = Encoder(embed_size)
decoder = Decoder(embed_size, hidden_size, len(vocab), num_layers)

encoder = encoder.to(device)
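# The lambda above delegates to a repo-specific collate_fn(data, vocab=vocab)
# that is not shown. A rough sketch of the usual image-captioning batcher it
# is assumed to resemble (the name, the sort-by-length step, and pad id 0 are
# assumptions): stack the transformed images and pad the tokenized captions.
import torch

def collate_fn_sketch(data, vocab):
    # data: list of (image_tensor, caption_tensor) pairs, captions as 1-D LongTensors
    data.sort(key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*data)
    images = torch.stack(images, dim=0)
    lengths = [len(c) for c in captions]
    padded = torch.zeros(len(captions), max(lengths), dtype=torch.long)  # 0 assumed to be the pad id
    for i, cap in enumerate(captions):
        padded[i, :lengths[i]] = cap
    return images, padded, lengths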
def __init__(self):
    # dataset
    logger.info('-' * 100)
    logger.info('Loading training and validation dataset')
    self.dataset = data.CodePtrDataset(mode='train')
    self.dataset_size = len(self.dataset)
    logger.info('Size of training dataset: {}'.format(self.dataset_size))
    self.dataloader = DataLoader(dataset=self.dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True,
                                 collate_fn=lambda *args: utils.collate_fn(args,
                                                                           source_vocab=self.source_vocab,
                                                                           code_vocab=self.code_vocab,
                                                                           ast_vocab=self.ast_vocab,
                                                                           nl_vocab=self.nl_vocab))

    # valid dataset
    self.valid_dataset = data.CodePtrDataset(mode='valid')
    self.valid_dataset_size = len(self.valid_dataset)
    self.valid_dataloader = DataLoader(dataset=self.valid_dataset,
                                       batch_size=config.valid_batch_size,
                                       collate_fn=lambda *args: utils.collate_fn(args,
                                                                                 source_vocab=self.source_vocab,
                                                                                 code_vocab=self.code_vocab,
                                                                                 ast_vocab=self.ast_vocab,
                                                                                 nl_vocab=self.nl_vocab))
    logger.info('Size of validation dataset: {}'.format(self.valid_dataset_size))
    logger.info('The datasets are successfully loaded')

    # vocab
    logger.info('-' * 100)
    logger.info('Building vocabularies')
    sources, codes, asts, nls = self.dataset.get_dataset()

    self.source_vocab = utils.build_word_vocab(dataset=sources,
                                               vocab_name='source',
                                               ignore_case=True,
                                               max_vocab_size=config.source_vocab_size,
                                               save_dir=config.vocab_root)
    self.source_vocab_size = len(self.source_vocab)
    logger.info('Size of source vocab: {} -> {}'.format(self.source_vocab.origin_size, self.source_vocab_size))

    self.code_vocab = utils.build_word_vocab(dataset=codes,
                                             vocab_name='code',
                                             ignore_case=True,
                                             max_vocab_size=config.code_vocab_size,
                                             save_dir=config.vocab_root)
    self.code_vocab_size = len(self.code_vocab)
    logger.info('Size of code vocab: {} -> {}'.format(self.code_vocab.origin_size, self.code_vocab_size))

    self.ast_vocab = utils.build_word_vocab(dataset=asts,
                                            vocab_name='ast',
                                            ignore_case=True,
                                            save_dir=config.vocab_root)
    self.ast_vocab_size = len(self.ast_vocab)
    logger.info('Size of ast vocab: {}'.format(self.ast_vocab_size))

    self.nl_vocab = utils.build_word_vocab(dataset=nls,
                                           vocab_name='nl',
                                           ignore_case=True,
                                           max_vocab_size=config.nl_vocab_size,
                                           save_dir=config.vocab_root)
    self.nl_vocab_size = len(self.nl_vocab)
    logger.info('Size of nl vocab: {} -> {}'.format(self.nl_vocab.origin_size, self.nl_vocab_size))

    logger.info('Vocabularies are successfully built')

    # model
    logger.info('-' * 100)
    logger.info('Building the model')
    self.model = models.Model(source_vocab_size=self.source_vocab_size,
                              code_vocab_size=self.code_vocab_size,
                              ast_vocab_size=self.ast_vocab_size,
                              nl_vocab_size=self.nl_vocab_size)
    # model device
    logger.info('Model device: {}'.format(next(self.model.parameters()).device))
    # log model statistics
    logger.info('Trainable parameters: {}'.format(utils.human_format(utils.count_params(self.model))))

    # optimizer
    self.optimizer = Adam([
        {'params': self.model.parameters(), 'lr': config.learning_rate},
    ])
    self.criterion = nn.CrossEntropyLoss(ignore_index=self.nl_vocab.get_pad_index())

    if config.use_lr_decay:
        self.lr_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                step_size=1,
                                                gamma=config.lr_decay_rate)

    # early stopping
    self.early_stopping = None
    if config.use_early_stopping:
        self.early_stopping = utils.EarlyStopping(patience=config.early_stopping_patience,
                                                  high_record=False)
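# utils.EarlyStopping is repo-specific and not shown. A hedged sketch of the
# kind of helper the constructor call above suggests (class and method names
# below are assumptions): with high_record=False, a lower score (e.g.
# validation loss) counts as an improvement, and training stops after
# `patience` consecutive checks without one.
class EarlyStoppingSketch:
    def __init__(self, patience, high_record=False):
        self.patience = patience
        self.high_record = high_record
        self.best = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, score):
        improved = (self.best is None or
                    (score > self.best if self.high_record else score < self.best))
        if improved:
            self.best = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop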