def _latest_checkpoint(ckpt_dir):
    """Return the path of the newest ``model-*.tar`` file in *ckpt_dir*, or None.

    Candidates are ordered by the last integer in their file name (e.g. the
    step in ``model-12000.tar``) instead of lexicographically. This keeps
    ``model-12000.tar`` ahead of ``model-9000.tar`` even when step numbers are
    not zero-padded; zero-padded names order the same either way.
    Files whose name contains no digits rank lowest.
    """
    import re  # local import: only this helper needs it

    paths = glob.glob(os.path.join(ckpt_dir, 'model-*.tar'))
    if not paths:
        return None

    def _order(path):
        numbers = re.findall(r'\d+', os.path.basename(path))
        # Tie-break on the path so the choice is deterministic.
        return (int(numbers[-1]) if numbers else -1, path)

    return max(paths, key=_order)


def main(DEVICE):
    """Build the TPGST model, resume from the latest checkpoint if any, and train.

    Configuration is read from the module-level ``args`` namespace (``logdir``,
    ``lr``, ``data_path``, ``meta``, ``mem_mode``, ``batch_size``,
    ``test_batch``, ``n_workers``). When a checkpoint is restored,
    ``args.global_step`` is overwritten with the step stored in it.

    :param DEVICE: device spec accepted by ``torch.nn.Module.to`` (e.g.
        ``'cpu'`` or ``'cuda'``); also used as ``map_location`` when loading
        a checkpoint.
    :return: None
    """
    model = TPGST().to(DEVICE)
    print('Model {} is working...'.format(type(model).__name__))
    ckpt_dir = os.path.join(args.logdir, type(model).__name__)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = LambdaLR(optimizer, lr_policy)

    resuming = os.path.exists(ckpt_dir)
    # exist_ok: a run directory left by an earlier run may lack this
    # sub-directory, and creating it is harmless when it already exists.
    os.makedirs(os.path.join(ckpt_dir, 'A', 'train'), exist_ok=True)

    if resuming:
        print('Already exists. Retrain the model.')
        model_path = _latest_checkpoint(ckpt_dir)
        if model_path is None:
            # Directory exists but nothing was ever saved: start fresh
            # instead of crashing on an empty file list.
            print('No checkpoint found in {}; training from scratch.'.format(ckpt_dir))
        else:
            # map_location lets a checkpoint saved on one device be resumed
            # on another (e.g. GPU-trained weights on a CPU-only machine).
            state = torch.load(model_path, map_location=DEVICE)
            model.load_state_dict(state['model'])
            args.global_step = state['global_step']
            optimizer.load_state_dict(state['optimizer'])
            # Restore the scheduler position and base learning rates.
            scheduler.last_epoch = state['scheduler']['last_epoch']
            scheduler.base_lrs = state['scheduler']['base_lrs']

    dataset = SpeechDataset(args.data_path, args.meta, mem_mode=args.mem_mode, training=True)
    validset = SpeechDataset(args.data_path, args.meta, mem_mode=args.mem_mode, training=False)
    data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size,
                             shuffle=True, collate_fn=collate_fn, drop_last=True,
                             pin_memory=True, num_workers=args.n_workers)
    valid_loader = DataLoader(dataset=validset, batch_size=args.test_batch,
                              shuffle=False, collate_fn=collate_fn, pin_memory=True)

    print('{} threads are used...'.format(torch.get_num_threads()))
    writer = SummaryWriter(ckpt_dir)
    train(model, data_loader, valid_loader, optimizer, scheduler,
          batch_size=args.batch_size, ckpt_dir=ckpt_dir, writer=writer, DEVICE=DEVICE)
    return None