Example #1
import glob
import os

import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# TPGST, SpeechDataset, collate_fn, lr_policy, train and args come from the
# surrounding project modules; the snippet relies on them already being in
# scope. SummaryWriter may equally come from tensorboardX in the original.

def main(DEVICE):
    """
    main function

    :param DEVICE: 'cpu' or 'gpu'

    """
    model = TPGST().to(DEVICE)

    print('Model {} is working...'.format(type(model).__name__))
    ckpt_dir = os.path.join(args.logdir, type(model).__name__)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # lr_policy returns a multiplier on args.lr for the current scheduler step.
    scheduler = LambdaLR(optimizer, lr_policy)

    if not os.path.exists(ckpt_dir):
        # Fresh run: create the checkpoint directory tree.
        os.makedirs(os.path.join(ckpt_dir, 'A', 'train'))
    else:
        print('Checkpoint directory already exists. Resuming from the latest checkpoint.')
        # Checkpoints are named model-*.tar; the lexicographic sort returns the
        # newest one as long as the step numbers in the filenames are zero-padded.
        model_path = sorted(glob.glob(os.path.join(
            ckpt_dir, 'model-*.tar')))[-1]  # latest model
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        state = torch.load(model_path, map_location=DEVICE)
        model.load_state_dict(state['model'])
        args.global_step = state['global_step']
        optimizer.load_state_dict(state['optimizer'])
        scheduler.last_epoch = state['scheduler']['last_epoch']
        scheduler.base_lrs = state['scheduler']['base_lrs']

    # Training and validation views of the same corpus; the `training` flag selects the split.
    dataset = SpeechDataset(args.data_path,
                            args.meta,
                            mem_mode=args.mem_mode,
                            training=True)
    validset = SpeechDataset(args.data_path,
                             args.meta,
                             mem_mode=args.mem_mode,
                             training=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             drop_last=True,    # keep only full-sized batches
                             pin_memory=True,   # page-locked memory for faster host-to-GPU copies
                             num_workers=args.n_workers)
    valid_loader = DataLoader(dataset=validset,
                              batch_size=args.test_batch,
                              shuffle=False,
                              collate_fn=collate_fn,
                              pin_memory=True)
    # torch.set_num_threads(4)
    print('{} threads are used...'.format(torch.get_num_threads()))

    writer = SummaryWriter(ckpt_dir)  # TensorBoard event files go next to the checkpoints
    train(model,
          data_loader,
          valid_loader,
          optimizer,
          scheduler,
          batch_size=args.batch_size,
          ckpt_dir=ckpt_dir,
          writer=writer,
          DEVICE=DEVICE)
    return None
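
A minimal usage sketch, assuming this function lives in a training script where args has already been parsed; the __main__ guard and the device-selection line are illustrative additions, not part of the original example:

if __name__ == '__main__':
    # Prefer the GPU when one is available; otherwise train on the CPU.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    main(DEVICE)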