def train(src, tgt, train_config, savedir, databin):
    """Run fairseq-train on `databin` for the src->tgt direction.

    Args:
        src: source language code (fairseq --source-lang).
        tgt: target language code (fairseq --target-lang).
        train_config: dict with 'gpu' (number of GPUs the recipe was tuned
            for) and 'parameters' (list of extra fairseq-train CLI flags).
        savedir: directory for checkpoints and the training log.
        databin: path to the binarized fairseq data directory.

    Skips training entirely if a previous run already finished (log ends
    with 'done' and the best checkpoint exists).

    Raises:
        RuntimeError: if no CUDA devices are visible.
    """
    os.makedirs(savedir, exist_ok=True)
    logpath = os.path.join(savedir, 'train.log')
    checkpoint = os.path.join(savedir, 'checkpoint_best.pt')
    # Resume guard: a completed run leaves 'done' as the last log line.
    if check_last_line(logpath, 'done') and os.path.exists(checkpoint):
        print(f"Training is finished. Best checkpoint: {checkpoint}")
        return
    cuda_visible_devices = list(range(torch.cuda.device_count()))
    print("CVD: ", cuda_visible_devices)
    num_visible_gpu = len(cuda_visible_devices)
    print("NVG: ", num_visible_gpu)
    if num_visible_gpu == 0:
        # math.log2(0) below would raise a cryptic ValueError; fail clearly.
        raise RuntimeError("No CUDA devices visible; cannot run fairseq-train.")
    # Use the largest power-of-two GPU count not exceeding the requested or
    # available number, so that update_freq below divides evenly.
    num_gpu = min(train_config['gpu'], 2 ** int(math.log2(num_visible_gpu)))
    print("NG: ", num_gpu)
    cuda_devices_clause = f"CUDA_VISIBLE_DEVICES={','.join(str(i) for i in cuda_visible_devices[:num_gpu])}"
    print("CDC: ", cuda_devices_clause)
    # Keep the effective batch size of the recipe constant: fewer GPUs means
    # proportionally more gradient accumulation. Integer division is required:
    # '/' yields a float (e.g. '2.0'), which fairseq's int-typed --update-freq
    # argument rejects.
    update_freq = train_config['gpu'] // num_gpu
    print("Update freq: ", update_freq)
    call(f"""{cuda_devices_clause} fairseq-train {databin} \
        --source-lang {src} --target-lang {tgt} \
        --save-dir {savedir} \
        --update-freq {update_freq} \
        {" ".join(train_config['parameters'])} \
        | tee {logpath} """, shell=True)
def main():
    """Drive the iterative back-translation pipeline described by --config.

    For each configured iteration: train a model on the current databin,
    evaluate BLEU on the test split, then (except for the last iteration)
    translate monolingual data with the new model and build a back-translation
    databin that the next iteration trains on.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', required=True, help='pipeline config')
    parser.add_argument('--databin', '-d', required=True, help='initial databin')
    args = parser.parse_args()
    configs = read_config(args.config)
    # Experiment artifacts live in ../experiments relative to this script.
    workdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../experiments')
    # initial_databin is rebound at the end of each iteration to the freshly
    # built back-translation databin, so iteration i trains on iteration i-1's
    # synthetic data.
    initial_databin = args.databin
    for i in range(len(configs)):
        (name, config) = configs[i]
        src = config['src']
        tgt = config['tgt']
        direction = f"{src}-{tgt}"
        print(f"Start {name} iteration, {direction}")
        iter_workdir = os.path.join(workdir, name, direction)
        # train
        model_dir = os.path.join(iter_workdir, 'model')
        train(src, tgt, config['train'], model_dir, initial_databin)
        checkpoint_path = os.path.join(model_dir, 'checkpoint_best.pt')
        # eval — reuse the cached score when eval.txt already ends in a BLEU
        # line; otherwise compute it. NOTE(review): evaluation always uses the
        # original args.databin (not the BT-augmented initial_databin),
        # presumably so every iteration scores the same test set — confirm.
        lenpen = config['translate']['lenpen']
        eval_output = os.path.join(model_dir, 'eval.txt')
        if check_last_line(eval_output, "BLEU"):
            print(
                check_output(f"tail -n 1 {eval_output}",
                             shell=True).decode('utf-8').strip())
        else:
            print(
                eval_bleu(config['src'], config['tgt'], 'test', lenpen,
                          args.databin, checkpoint_path,
                          os.path.join(model_dir, 'eval.txt')))
        # Early exit to skip back-translation for the last iteration
        if i == len(configs) - 1:
            break
        # translate monolingual data with the freshly trained model
        translate_output = os.path.join(iter_workdir, 'synthetic')
        translate(config['src'], config['tgt'], checkpoint_path, lenpen,
                  translate_output, config['translate']['mono'],
                  config['translate']['max_token'])
        # generate databin — note tgt/src are swapped: the synthetic pairs are
        # used in the reverse direction for the next iteration's training.
        databin_folder = os.path.join(translate_output, 'bt')
        initial_databin = build_bt_databin(
            config['tgt'], config['src'],
            os.path.join(translate_output, 'generated'), args.databin,
            databin_folder)
def check_finished(output_file):
    """Return True if the last line of `output_file` marks the job finished."""
    finished_marker = "finished"
    return check_last_line(output_file, finished_marker)