Example 1
def train(dataset_name: str,
          model_name: str,
          expt_dir: str,
          data_folder: str,
          num_workers: int = 0,
          is_test: bool = False,
          resume_from_checkpoint: str = None):
    seed_everything(SEED)
    dataset_main_folder = data_folder
    vocab = Vocabulary.load(join(dataset_main_folder, "vocabulary.pkl"))

    if model_name == "code2seq":
        config_function = get_code2seq_test_config if is_test else get_code2seq_default_config
        config = config_function(dataset_main_folder)
        model = Code2Seq(config, vocab, num_workers)
        model.half()
    # elif model_name == "code2class":
    #     config_function = get_code2class_test_config if is_test else get_code2class_default_config
    #     config = config_function(dataset_main_folder)
    #     model = Code2Class(config, vocab, num_workers)
    else:
        raise ValueError(f"Model {model_name} is not supported")

    # define logger
    wandb_logger = WandbLogger(project=f"{model_name}-{dataset_name}",
                               log_model=True,
                               offline=True)
    wandb_logger.watch(model)
    # define model checkpoint callback
    model_checkpoint_callback = ModelCheckpoint(
        filepath=join(expt_dir, "{epoch:02d}-{val_loss:.4f}"),
        period=config.hyperparams.save_every_epoch,
        save_top_k=3,
    )
    # define early stopping callback
    early_stopping_callback = EarlyStopping(
        patience=config.hyperparams.patience, verbose=True, mode="min")
    # use gpu if it exists
    gpu = 1 if torch.cuda.is_available() else None
    # define learning rate logger
    lr_logger = LearningRateLogger()
    trainer = Trainer(
        max_epochs=20,
        gradient_clip_val=config.hyperparams.clip_norm,
        deterministic=True,
        check_val_every_n_epoch=config.hyperparams.val_every_epoch,
        row_log_interval=config.hyperparams.log_every_epoch,
        logger=wandb_logger,
        checkpoint_callback=model_checkpoint_callback,
        early_stop_callback=early_stopping_callback,
        resume_from_checkpoint=resume_from_checkpoint,
        gpus=gpu,
        callbacks=[lr_logger],
        reload_dataloaders_every_epoch=True,
    )
    trainer.fit(model)
    trainer.save_checkpoint(join(expt_dir, 'Latest.ckpt'))

    trainer.test()
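
A minimal sketch of how this entry point could be invoked; the dataset name, directories, and worker count below are illustrative placeholders, not values taken from the original project.

# Hypothetical invocation of train(); every path and name here is a placeholder.
if __name__ == "__main__":
    train(
        dataset_name="java-small",        # assumed dataset name
        model_name="code2seq",            # the only model handled above
        expt_dir="experiments/code2seq",  # where checkpoints are written
        data_folder="data/java-small",    # must contain vocabulary.pkl
        num_workers=4,
        is_test=False,
    )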
Example 2
def evaluate(checkpoint: str, data: str = None, batch_size: int = None):
    seed_everything(SEED)
    model = Code2Seq.load_from_checkpoint(checkpoint_path=checkpoint)
    batch_size = batch_size or model.hyperparams.test_batch_size
    data = data or model.hyperparams.test_data_path
    gpu = 1 if torch.cuda.is_available() else None
    data_loader, n_samples = create_dataloader(
        data, model.hyperparams.max_context, False, False, batch_size, cpu_count(),
    )
    print(f"approximate number of steps for test is {ceil(n_samples / batch_size)}")
    trainer = Trainer(gpus=gpu)
    trainer.test(model, test_dataloaders=data_loader)
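
A minimal sketch of calling this helper, assuming a checkpoint produced by a training run such as Example 1; the file paths and batch size are placeholders.

# Hypothetical usage of evaluate(); checkpoint and data paths are placeholders.
evaluate(
    checkpoint="experiments/code2seq/Latest.ckpt",
    data="data/java-small/java-small.test.c2s",
    batch_size=512,
)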
Example 3
def evaluate(checkpoint: str, data: str = None):
    seed_everything(SEED)
    model = Code2Seq.load_from_checkpoint(checkpoint_path=checkpoint)
    gpu = 1 if torch.cuda.is_available() else None
    trainer = Trainer(gpus=gpu)
    if data is not None:
        data_loader, n_samples = create_dataloader(
            join(DATA_FOLDER, data), model.config.max_context, False, False,
            model.config.test_batch_size, cpu_count())
        print(
            f"approximate number of steps for test is {ceil(n_samples / model.config.test_batch_size)}"
        )
        trainer.test(model, test_dataloaders=data_loader)
    else:
        trainer.test(model)
Example 4
def get_code2seq(config: DictConfig, vocabulary: Vocabulary) -> Tuple[LightningModule, LightningDataModule]:
    model = Code2Seq(config, vocabulary)
    data_module = PathContextDataModule(config, vocabulary)
    return model, data_module
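
The returned pair is meant to be handed to a PyTorch Lightning Trainer; a minimal sketch, assuming `config` and `vocabulary` have already been loaded elsewhere.

# Minimal sketch: hand the model and data module to a Lightning Trainer.
from pytorch_lightning import Trainer

model, data_module = get_code2seq(config, vocabulary)  # config/vocabulary assumed loaded
trainer = Trainer(max_epochs=10)
trainer.fit(model, datamodule=data_module)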
Example 5
def load_code2seq(
        checkpoint_path: str, config: DictConfig,
        vocabulary: Vocabulary) -> Tuple[Code2Seq, PathContextDataModule]:
    model = Code2Seq.load_from_checkpoint(checkpoint_path=checkpoint_path)
    data_module = PathContextDataModule(config, vocabulary)
    return model, data_module
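
Correspondingly, a restored model can be tested through its data module; a minimal sketch, with the checkpoint path as a placeholder.

# Minimal sketch: restore a trained model and run the test loop; the path is a placeholder.
from pytorch_lightning import Trainer

model, data_module = load_code2seq("experiments/code2seq/Latest.ckpt", config, vocabulary)
trainer = Trainer()
trainer.test(model, datamodule=data_module)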
Example 6
	print('Done dumping reduced data set')
	return out_path


if __name__=="__main__":
	opt = parse_args()
	print(opt)
	print('data path: ', opt.data_path)
	data_split = opt.data_path.split('/')[-2]
	print('data_split', data_split)

	# replace_tokens = ["@R_%d@"%x for x in range(0,opt.num_replacements+1)]
	replace_tokens = ["@R_%d@"%x for x in range(1000)]
	
	model = Code2Seq.load_from_checkpoint(checkpoint_path=opt.expt_dir)

	data_loader, n_samples = create_dataloader(
		opt.data_path, model.hyperparams.max_context, False, False, opt.batch_size, 1,
	)

	vocab = pickle.load(open(opt.vocab, 'rb'))
	token_to_id = vocab['token_to_id']
	id_to_token = {token_to_id[t]:t for t in token_to_id}
	print('length: ', len(id_to_token))
	label_to_id = vocab['label_to_id']
	id_to_label = {label_to_id[t]:t for t in label_to_id}


	# if data_split == 'test' and opt.exact_matches:
Example 7
    args = parser.parse_args()
    return args


def create_datafile(data_path, exact_matches, split):

    new_data_path = os.path.join(data_path, 'small.{}.c2s'.format(split))
    # use context managers so both files are closed (and the output flushed) on exit
    with open(os.path.join(data_path, 'data.{}.c2s'.format(split)), 'r') as lines, \
            open(new_data_path, 'w') as new_file:
        for line in lines:
            if line.split()[0] in exact_matches:
                new_file.write(line)
    print("Saved exact matches.")


if __name__ == '__main__':

    args = parse_args()
    model = Code2Seq.load_from_checkpoint(checkpoint_path=args.checkpoint)
    data_loader, n_samples = create_dataloader(
        os.path.join(args.orig_data_path, args.split),
        model.hyperparams.max_context, False, False, args.batch_size, 1)
    vocab = pickle.load(open(args.vocab_path, 'rb'))
    label_to_id = vocab['label_to_id']
    id_to_label = {label_to_id[l]: l for l in label_to_id}

    li_exact_matches = get_exact_matches(data_loader, n_samples, model,
                                         id_to_label)
    print(li_exact_matches)
    create_datafile(args.data_path, li_exact_matches, args.split)
Example 8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch",
                        "-e",
                        type=int,
                        default=1,
                        help="Number of training epochs")
    parser.add_argument("--batchsize",
                        "-b",
                        type=int,
                        default=1,
                        help="Number of examples in each mini-batch")
    parser.add_argument("--gpu",
                        "-g",
                        type=int,
                        default=-1,
                        help="GPU ID (negative value indicates CPU)")
    parser.add_argument("--out",
                        "-o",
                        default="result",
                        help="Directory to output the result")
    parser.add_argument("--resume",
                        "-r",
                        default="",
                        help="Resume the training from snapshot")
    # *2 because the path encoder is a bidirectional LSTM
    parser.add_argument("--path_rnn_size",
                        type=int,
                        default=128 * 2,
                        help="embedding size of path")
    parser.add_argument("--embed_size",
                        type=int,
                        default=128,
                        help="embedding size")
    parser.add_argument("--datapath",
                        default="./java-small/java-small.dict.c2v",
                        help="path of input data")
    parser.add_argument("--trainpath",
                        default="./data/java14m/java14m.train.c2v",
                        help="path of train data")
    parser.add_argument("--validpath",
                        default="./data/java14m/java14m.val.c2v",
                        help="path of valid data")
    parser.add_argument("--savename", default="", help="name of saved model")
    parser.add_argument("--trainnum",
                        type=int,
                        default=15344512,
                        help="size of train data")
    parser.add_argument("--validnum",
                        type=int,
                        default=320866,
                        help="size of valid data")
    parser.add_argument("--context_length",
                        type=int,
                        default=200,
                        help="length of context")
    parser.add_argument("--terminal_length",
                        type=int,
                        default=5,
                        help="length of terminal")
    parser.add_argument("--path_length",
                        type=int,
                        default=9,
                        help="length of path")
    parser.add_argument("--target_length",
                        type=int,
                        default=8,
                        help="length of target")
    parser.add_argument("--eval", action="store_true", help="is eval")
    parser.add_argument("--path_rnn_drop",
                        type=float,
                        default=0.5,
                        help="dropout rate of the path RNN")
    parser.add_argument("--embed_drop",
                        type=float,
                        default=0.25,
                        help="dropout rate of the embeddings")
    parser.add_argument("--num_worker",
                        type=int,
                        default=0,
                        help="number of DataLoader workers")
    parser.add_argument("--decode_size",
                        type=int,
                        default=320,
                        help="decode size")

    args = parser.parse_args()

    device = torch.device(
        args.gpu if args.gpu != -1 and torch.cuda.is_available() else "cpu")
    print(device)

    with open(args.datapath, "rb") as file:
        terminal_counter = pickle.load(file)
        path_counter = pickle.load(file)
        target_counter = pickle.load(file)
        # _ = pickle.load(file)
        # _ = pickle.load(file)
        print("Dictionaries loaded.")
    train_h5 = h5py.File(args.trainpath, "r")
    test_h5 = h5py.File(args.validpath, "r")

    terminal_dict = {
        w: i
        for i, w in enumerate(sorted([w for w, c in terminal_counter.items()]))
    }
    terminal_dict["<unk>"] = len(terminal_dict)
    terminal_dict["<pad>"] = len(terminal_dict)
    path_dict = {w: i for i, w in enumerate(sorted(path_counter.keys()))}
    path_dict["<unk>"] = len(path_dict)
    path_dict["<pad>"] = len(path_dict)
    target_dict = {
        w: i
        for i, w in enumerate(sorted([w for w, c in target_counter.items()]))
    }
    target_dict["<unk>"] = len(target_dict)
    target_dict["<bos>"] = len(target_dict)
    target_dict["<pad>"] = len(target_dict)

    print("terminal_vocab:", len(terminal_dict))
    print("target_vocab:", len(target_dict))

    c2s = Code2Seq(args, terminal_vocab_size=len(terminal_dict),
                   path_element_vocab_size=len(path_dict),
                   target_dict=target_dict, device=device,
                   path_embed_size=args.embed_size,
                   terminal_embed_size=args.embed_size,
                   path_rnn_size=args.path_rnn_size,
                   target_embed_size=args.embed_size,
                   decode_size=args.decode_size)\
        .to(device)

    if args.resume != "":
        c2s.load_state_dict(torch.load(args.resume))

    trainloader = DataLoader(C2SDataSet(args, train_h5, args.trainnum,
                                        terminal_dict, path_dict, target_dict,
                                        device),
                             batch_size=args.batchsize,
                             shuffle=True,
                             num_workers=args.num_worker)

    validloader = DataLoader(C2SDataSet(args, test_h5, args.validnum,
                                        terminal_dict, path_dict, target_dict,
                                        device),
                             batch_size=args.batchsize,
                             shuffle=True,
                             num_workers=args.num_worker)

    optimizer = optim.SGD(c2s.parameters(), lr=0.01, momentum=0.95)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=1,
                                          gamma=0.95,
                                          last_epoch=-1)

    for epoch in range(args.epoch):
        if not args.eval:
            trainloader = DataLoader(C2SDataSet(args, train_h5, args.trainnum,
                                                terminal_dict, path_dict,
                                                target_dict, device),
                                     batch_size=args.batchsize,
                                     shuffle=True,
                                     num_workers=args.num_worker)

            sum_loss = 0
            train_count = 0
            c2s.train()
            scheduler.step()  # stepped once per epoch, hence outside the batch loop
            for data in tqdm.tqdm(trainloader):
                loss = c2s(*data, is_eval=False)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                sum_loss += loss.item()
                if train_count % 250 == 0 and train_count != 0:
                    print(sum_loss / 250)
                    sum_loss = 0
                train_count += 1
        true_positive, false_positive, false_negative = 0, 0, 0
        for data in tqdm.tqdm(validloader):
            c2s.eval()
            with torch.no_grad():
                true_positive_, false_positive_, false_negative_ = c2s(
                    *data, is_eval=True)
            true_positive += true_positive_
            false_positive += false_positive_
            false_negative += false_negative_

        pre_score, rec_score, f1_score = calculate_results(
            true_positive, false_positive, false_negative)
        print("f1:", f1_score, "prec:", pre_score, "rec:", rec_score)
        if args.eval:
            break
        if args.savename != "":
            torch.save(c2s.state_dict(), args.savename + str(epoch) + ".model")
    if args.savename != "":
        torch.save(c2s.state_dict(), args.savename + ".model")
Example 9
                        help="size of train data")
    parser.add_argument("--validnum",
                        type=int,
                        default=23844,
                        help="size of valid data")
    args = parser.parse_args()

    seed = 7
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dicts = Dictionaries(config)

    # model = torch.load("data/1590690332408_iteration_6800_epoch_0.tar", map_location=torch.device('cpu'))
    model = Code2Seq(dicts).to(device)
    model.train(True)

    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.Adam(model.parameters())

    train_h5 = h5py.File(args.trainpath, 'r')
    val_h5 = h5py.File(args.validpath, 'r')

    train_set = C2SDataSet(config, train_h5, args.trainnum,
                           dicts.subtoken_to_index, dicts.node_to_index,
                           dicts.target_to_index, device)

    # train_set = Subset(train_set, list(range(1280)))

    val_set = C2SDataSet(config, val_h5, args.validnum,