Ejemplo n.º 1
0
def main(argv):
    """Train or evaluate a model according to the parsed options in *argv*.

    Side effects: seeds torch, creates the result directory, optionally
    restores a checkpoint (overriding most of *argv* with the saved one),
    and dispatches to train_model / test_model.

    Args:
        argv: parsed argument namespace; must carry gpu_id, seed, res_root,
            model_folder, model, mode, use_checkpoint, checkpoint_name,
            checkpoint_ver, output (attributes read below).
    """
    print("Torch GPU set:", argv.gpu_id)
    torch.manual_seed(argv.seed)
    print("Torch Seed was:", argv.seed)
    utils.mkdir(os.path.join(argv.res_root, argv.model_folder))

    loader = DataLoader(argv)
    loader.read_data()
    argv.desired_len = loader.desired_len
    # assert loader.max_idx == argv.max_idx

    # 'nb' (presumably naive Bayes) is the only non-neural model choice.
    is_nn = (argv.model != 'nb')
    model = create_model(argv, loader)

    # Restore a checkpoint when testing, or when explicitly requested.
    checkpoint = None
    if argv.mode == 'test' or argv.use_checkpoint:
        model_dir = os.path.join(argv.res_root, argv.model_folder)
        chkpt_load_path = None
        for file in os.listdir(model_dir):
            if f"{argv.checkpoint_name}_epoch{argv.checkpoint_ver}" in file:
                chkpt_load_path = os.path.join(model_dir, file)
                break
        if chkpt_load_path is None:
            # FileNotFoundError is more precise than a bare Exception and is
            # still an Exception subclass, so existing handlers keep working.
            raise FileNotFoundError("Can't find checkpoint")

        print(f"\tLoading {chkpt_load_path}")
        checkpoint = torch.load(chkpt_load_path)
        # Options from the current invocation that must override the ones
        # stored inside the checkpoint.
        checkpoint['argv'].output = argv.output
        checkpoint['argv'].mode = argv.mode
        checkpoint['argv'].use_checkpoint = argv.use_checkpoint
        # Explicit check instead of `assert`: asserts are stripped under -O,
        # which would silently skip this validation.
        if checkpoint['epoch'] != argv.checkpoint_ver:
            raise ValueError(
                f"Checkpoint epoch {checkpoint['epoch']} does not match "
                f"requested checkpoint_ver {argv.checkpoint_ver}")
        checkpoint['argv'].checkpoint_ver = argv.checkpoint_ver

        # From here on, run with the restored hyper-parameters.
        argv = checkpoint['argv']
        model = create_model(argv, loader)
        loader.desired_len = argv.desired_len
        loader.batch_size = argv.batch_size

    if USE_CUDA and is_nn:
        torch.cuda.set_device(argv.gpu_id)
        model = model.cuda()

    print(f"\n{argv.mode} {type(model).__name__} {'#'*50}")
    if argv.mode == 'test':
        model.load_state_dict(checkpoint['model_state_dict'])
        test_model(model, loader, is_nn)
    else:
        # train_model receives the raw checkpoint (or None) and can derive
        # the resume epoch itself; the old unused `epoch` local was dropped.
        train_model(model, loader, checkpoint, is_nn)

    print()
Ejemplo n.º 2
0
from argparse import Namespace

# Source corpus and destination directory for the ensemble data split.
data_path = "./src_zzc/data/Training.txt"
dest = "./src_zzc/data_ensemble"

# Minimal option namespace that DataLoader expects (values mirror the
# training configuration; seed is fixed for a reproducible split).
argv = Namespace()
argv.data_path = data_path
argv.batch_size = 50
argv.fold = 5
argv.mode = 'train'
argv.desired_len_percent = 0.5
argv.num_label = 20
argv.seed = 1587064891

loader = DataLoader(argv)
loader.read_data()

# Dump the training split: one label per line, and the matching message as
# comma-separated word ids. `with` guarantees both files are flushed and
# closed even if a write raises (the original open()/close() pair leaked
# the handles on error).
with open(os.path.join(dest, 'Training.txt'), 'w') as train, \
        open(os.path.join(dest, 'Training_Label.txt'), 'w') as train_label:
    for gt, msg in loader.data:
        train_label.write(f"{gt}\n")
        train.write(",".join(str(word) for word in msg))
        train.write("\n")

train = open(os.path.join(dest, 'Training_val.txt'), 'w')
train_label = open(os.path.join(dest, 'Training_val_Label.txt'), 'w')
for gt, msg in loader.data_val:
    train_label.write(f"{gt}\n")