def main(args):
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_params() for _ in range(len(model_cls_list))]
    params_list = [
        merge_params(params, model_cls.default_params())
        for params, model_cls in zip(params_list, model_cls_list)]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))]
    params_list = [
        override_params(params_list[i], args)
        for i in range(len(model_cls_list))]
    params = params_list[0]

    dist.init_process_group("nccl", init_method=args.url,
                            rank=args.local_rank,
                            world_size=len(params.device_list))
    torch.cuda.set_device(params.device_list[args.local_rank])
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    # Create model
    with torch.no_grad():
        model_list = []

        if len(args.input) == 1:
            mode = "infer"
            if params.from_torchtext:
                dataset = data.get_dataset_torchtext(args.input[0], mode,
                                                     params)
            else:
                dataset = data.get_dataset(args.input[0], mode, params)
        else:
            # Teacher-forcing
            mode = "eval"
            if params.from_torchtext:
                dataset = data.get_dataset_torchtext(args.input, mode, params)
            else:
                dataset = data.get_dataset(args.input, mode, params)

        iterator = iter(dataset)
        idx = 0
        counter = 0
        pad_max = 1024
        top_beams = params.top_beams
        decode_batch_size = params.decode_batch_size

        # Count eval dataset
        total_len = 0
        max_length = 0

        for sample in iterator:
            total_len += 1
            length = sample['source'].shape[1]
            if length > max_length:
                max_length = length

        iterator = iter(dataset)

        for param in params_list:
            if hasattr(param, "max_length"):
                param.max_length = min(param.max_length, max_length)
            else:
                param.max_length = max_length

        for i in range(len(args.models)):
            model = model_cls_list[i](params_list[i]).cuda()

            if args.half:
                model = model.half()

            model.eval()
            model.load_state_dict(
                torch.load(utils.latest_checkpoint(args.checkpoints[i]),
                           map_location="cpu")["model"])

            model_list.append(model)

        # Buffers for synchronization
        size = torch.zeros([dist.get_world_size()]).long()
        t_list = [torch.empty([decode_batch_size, top_beams, pad_max]).long()
                  for _ in range(dist.get_world_size())]

        if dist.get_rank() == 0:
            fd = open(args.output, "wb")
            pbar = tqdm(total=total_len)
            pbar.set_description("Translating to {}".format(args.output))
        else:
            fd = None

        states = [None for _ in model_list]

        if "cachedtransformer" in [model.name for model in model_list]:
            last_features = [None for _ in model_list]
            for model in model_list:
                if model.name == "cachedtransformer":
                    model.encoder.cache.set_batch_size(
                        params.decode_batch_size)
                    model.decoder.cache.set_batch_size(
                        params.decode_batch_size)

        while True:
            try:
                features = next(iterator)
                features = data.lookup(features, mode, params,
                                       from_torchtext=params.from_torchtext)
                if mode == "eval":
                    features = features[0]
                batch_size = features["source"].shape[0]
            except:
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }
                if mode == "eval":
                    features["target"] = torch.ones([1, 1]).long()
                    features["target_mask"] = torch.ones([1, 1]).float()
                batch_size = 0
            finally:
                for im, model in enumerate(model_list):
                    if model.name == "cachedtransformer":
                        features = update_cache(model, features, states[im],
                                                last_features[im],
                                                evaluate=True)
                        last_features[im] = features

            counter += 1

            # Decode
            if mode != "eval":
                seqs, _, states = utils.beam_search(model_list, features,
                                                    params)
            else:
                seqs, _ = utils.argmax_decoding(model_list, features, params)

            # Padding
            pad_batch = decode_batch_size - seqs.shape[0]
            pad_beams = top_beams - seqs.shape[1]
            pad_length = pad_max - seqs.shape[2]
            seqs = torch.nn.functional.pad(
                seqs, (0, pad_length, 0, pad_beams, 0, pad_batch))

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, seqs)

            if size.sum() == 0:
                break

            if dist.get_rank() != 0:
                continue

            for i in range(decode_batch_size):
                for j in range(dist.get_world_size()):
                    for k in range(top_beams):
                        n = size[j]
                        seq = convert_to_string(t_list[j][i][k], params)

                        if i >= n:
                            continue

                        if top_beams == 1:
                            fd.write(seq)
                            fd.write(b"\n")
                        else:
                            fd.write(str(idx).encode("utf-8"))
                            fd.write(b"\t")
                            fd.write(str(k).encode("utf-8"))
                            fd.write(b"\t")
                            fd.write(seq)
                            fd.write(b"\n")

                    idx = idx + 1

            if dist.get_rank() == 0:
                pbar.update(1)

        if dist.get_rank() == 0:
            pbar.close()
            fd.close()
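
# NOTE: The translator above pads `seqs` to the fixed shape
# [decode_batch_size, top_beams, pad_max] and all-reduces a per-rank
# `size` vector before calling dist.all_gather, because all_gather
# requires identical tensor shapes on every rank and a size sum of zero
# signals global end-of-data. The helper below is a minimal,
# self-contained sketch of that pattern, not part of the toolkit; the
# name `gather_padded` and the shapes are illustrative only.

import torch
import torch.distributed as dist


def gather_padded(seqs, batch_size, pad_shape):
    # Gather variable-sized decode results across ranks. Each rank pads
    # its tensor up to `pad_shape`; the all-reduced `size` vector tells
    # rank 0 which rows are real padding-free data.
    world_size = dist.get_world_size()

    size = torch.zeros([world_size]).long()
    size[dist.get_rank()] = batch_size
    dist.all_reduce(size)

    padded = torch.zeros(pad_shape).long()
    padded[:seqs.shape[0], :seqs.shape[1], :seqs.shape[2]] = seqs

    buffers = [torch.empty(pad_shape).long() for _ in range(world_size)]
    dist.all_gather(buffers, padded)
    return buffers, size

# Usage (after dist.init_process_group, as in the translators):
#   buffers, size = gather_padded(
#       seqs, batch_size, [decode_batch_size, top_beams, pad_max])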
def main(args):
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_params() for _ in range(len(model_cls_list))]
    params_list = [
        merge_params(params, model_cls.default_params())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_params(params_list[i], args)
        for i in range(len(model_cls_list))
    ]
    params = params_list[0]

    if args.cpu:
        dist.init_process_group("gloo", init_method=args.url,
                                rank=args.local_rank,
                                world_size=1)
        torch.set_default_tensor_type(torch.FloatTensor)
    else:
        dist.init_process_group("nccl", init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    # Create model
    with torch.no_grad():
        model_list = []

        for i in range(len(args.models)):
            if args.cpu:
                model = model_cls_list[i](params_list[i])
            else:
                model = model_cls_list[i](params_list[i]).cuda()

            if args.half:
                model = model.half()

            model.eval()
            model.load_state_dict(
                torch.load(utils.latest_checkpoint(args.checkpoints[i]),
                           map_location="cpu")["model"])

            model_list.append(model)

        if len(args.input) == 1:
            mode = "infer"
            sorted_key, dataset = data.get_dataset(args.input[0], mode, params)
        else:
            # Teacher-forcing
            mode = "eval"
            dataset = data.get_dataset(args.input, mode, params)
            sorted_key = None

        iterator = iter(dataset)
        counter = 0
        pad_max = 1024
        top_beams = params.top_beams
        decode_batch_size = params.decode_batch_size

        # Buffers for synchronization
        size = torch.zeros([dist.get_world_size()]).long()
        t_list = [
            torch.empty([decode_batch_size, top_beams, pad_max]).long()
            for _ in range(dist.get_world_size())
        ]

        all_outputs = []

        while True:
            try:
                features = next(iterator)
                features = data.lookup(features, mode, params,
                                       to_cpu=args.cpu)
                if mode == "eval":
                    features = features[0]
                batch_size = features["source"].shape[0]
            except:
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }
                if mode == "eval":
                    features["target"] = torch.ones([1, 1]).long()
                    features["target_mask"] = torch.ones([1, 1]).float()
                batch_size = 0

            t = time.time()
            counter += 1

            # Decode
            if mode != "eval":
                seqs, _ = utils.beam_search(model_list, features, params)
            else:
                seqs, _ = utils.argmax_decoding(model_list, features, params)

            # Padding
            pad_batch = decode_batch_size - seqs.shape[0]
            pad_beams = top_beams - seqs.shape[1]
            pad_length = pad_max - seqs.shape[2]
            seqs = torch.nn.functional.pad(
                seqs, (0, pad_length, 0, pad_beams, 0, pad_batch))

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))

            if args.cpu:
                t_list[dist.get_rank()].copy_(seqs)
            else:
                dist.all_reduce(size)
                dist.all_gather(t_list, seqs)

            if size.sum() == 0:
                break

            if dist.get_rank() != 0:
                continue

            for i in range(decode_batch_size):
                for j in range(dist.get_world_size()):
                    beam_seqs = []
                    pad_flag = i >= size[j]

                    for k in range(top_beams):
                        seq = convert_to_string(t_list[j][i][k], params)

                        if pad_flag:
                            continue

                        beam_seqs.append(seq)

                    if pad_flag:
                        continue

                    all_outputs.append(beam_seqs)

            t = time.time() - t
            print("Finished batch: %d (%.3f sec)" % (counter, t))

        if dist.get_rank() == 0:
            restored_outputs = []

            if sorted_key is not None:
                for idx in range(len(all_outputs)):
                    restored_outputs.append(all_outputs[sorted_key[idx]])
            else:
                restored_outputs = all_outputs

            with open(args.output, "wb") as fd:
                if top_beams == 1:
                    for seqs in restored_outputs:
                        fd.write(seqs[0] + b"\n")
                else:
                    for idx, seqs in enumerate(restored_outputs):
                        for k, seq in enumerate(seqs):
                            fd.write(b"%d\t%d\t" % (idx, k))
                            fd.write(seq + b"\n")
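
# NOTE: In "infer" mode the dataset is (presumably) sorted by length for
# efficient batching, and `sorted_key` is used to restore the original
# line order before writing. The demo below illustrates that idea under
# the assumption, implied by the restore loop above, that sorted_key is
# the inverse permutation of the length sort; the strings and the
# "translation" (upper-casing) are made up.

def _restore_order_demo():
    lines = ["bb", "a", "cccc"]                      # original file order
    order = sorted(range(len(lines)), key=lambda i: len(lines[i]))
    sorted_lines = [lines[i] for i in order]         # order fed to the model
    outputs = [s.upper() for s in sorted_lines]      # decoded in sorted order

    # sorted_key[idx] = position of original line idx in the sorted stream.
    sorted_key = [order.index(i) for i in range(len(lines))]
    restored = [outputs[sorted_key[idx]] for idx in range(len(outputs))]
    assert restored == ["BB", "A", "CCCC"]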
def main(args):
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        dist.init_process_group("nccl", init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()

    if args.half:
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)

    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)

    if params.optimizer.lower() == "adam":
        optimizer = optimizers.AdamOptimizer(learning_rate=schedule,
                                             beta_1=params.adam_beta1,
                                             beta_2=params.adam_beta2,
                                             epsilon=params.adam_epsilon,
                                             clipper=clipper,
                                             summaries=params.save_summary)
    elif params.optimizer.lower() == "adadelta":
        optimizer = optimizers.AdadeltaOptimizer(
            learning_rate=schedule, rho=params.adadelta_rho,
            epsilon=params.adadelta_epsilon, clipper=clipper,
            summaries=params.save_summary)
    elif params.optimizer.lower() == "sgd":
        optimizer = optimizers.SGDOptimizer(learning_rate=schedule,
                                            clipper=clipper,
                                            summaries=params.save_summary)
    else:
        raise ValueError("Unknown optimizer %s" % params.optimizer)

    if args.half:
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    if dist.get_rank() == 0:
        print_variables(model)

    if params.from_torchtext:
        dataset = data.get_dataset_torchtext(params.input, "train", params)
    else:
        dataset = data.get_dataset(params.input, "train", params)

    if params.validation:
        if params.from_torchtext:
            eval_dataset = data.get_dataset_torchtext(params.validation,
                                                      "infer", params)
        else:
            eval_dataset = data.get_dataset(params.validation, "infer",
                                            params)
        references = load_references(params.references)
    else:
        eval_dataset = None
        references = None

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if args.checkpoint is not None:
        # Load pre-trained models
        state = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state["model"], strict=False)
        step = params.initial_step
        epoch = 0
        broadcast(model)
    elif checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        features, labels = inputs
        loss, state = model(features, labels)
        return loss, state

    counter = 0
    state = None

    if params.model == "cachedtransformer":
        last_feature = None

    while True:
        start_time = time.time()

        for features in dataset:
            if counter % params.update_cycle == 0:
                step += 1
                utils.set_global_step(step)

            counter += 1
            t = time.time()
            features = data.lookup(features, "train", params,
                                   from_torchtext=params.from_torchtext)

            if model.name == "cachedtransformer":
                features = utils.update_cache(model, features, state,
                                              last_feature)
                last_feature = features[0]

            loss, state = train_fn(features)
            gradients = optimizer.compute_gradients(loss,
                                                    list(model.parameters()))
            grads_and_vars = optimizers.exclude_variables(
                params.pattern,
                zip(gradients, list(model.named_parameters())))
            optimizer.apply_gradients(grads_and_vars)

            t = time.time() - t
            summary.scalar("loss", loss, step, write_every_n_steps=1)
            summary.scalar("global_step/sec", t, step)

            if counter % params.update_cycle == 0:
                if step > 0 and step % args.log_interval == 0:
                    elapsed = time.time() - start_time
                    print('| epoch {:2d} | step {:8d} | lr {:02.2e} | '
                          'ms/step {:4.0f} | loss {:8.4f} '.format(
                              epoch + 1, step,
                              optimizer._optimizer._learning_rate(step),
                              elapsed * 1000 / args.log_interval,
                              loss.item()))
                    start_time = time.time()

                if step >= params.train_steps:
                    utils.evaluate(model, eval_dataset, params.output,
                                   references, params)
                    save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

                if step % params.eval_steps == 0:
                    utils.evaluate(model, eval_dataset, params.output,
                                   references, params)
                    start_time = time.time()

                if step % params.save_checkpoint_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, params)
                    start_time = time.time()

        epoch += 1
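
# NOTE: The `counter % params.update_cycle == 0` bookkeeping together with
# MultiStepOptimizer implements gradient accumulation: several mini-batch
# gradients are accumulated before one parameter update. The function
# below is a minimal sketch of the same idea using plain torch.optim
# (toy model, made-up sizes), not THUMT's optimizer wrappers.

import torch


def _grad_accumulation_sketch(update_cycle=4, steps=16):
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    for counter in range(steps):
        x = torch.randn(2, 8)
        # Scale the loss so the accumulated gradient is the average over
        # the cycle rather than the sum.
        loss = model(x).pow(2).mean() / update_cycle
        loss.backward()

        # Only every update_cycle-th mini-batch triggers a real update,
        # mirroring the update_cycle check in the training loop above.
        if (counter + 1) % update_cycle == 0:
            optimizer.step()
            optimizer.zero_grad()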
def main(args):
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        dist.init_process_group("nccl", init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()

    if args.half:
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)

    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)
    optimizer = get_optimizer(params, schedule, clipper)

    if args.half:
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    trainable_flags = print_variables(model, params.pattern,
                                      dist.get_rank() == 0)

    dataset = data.get_dataset(params.input, "train", params)

    if params.validation:
        sorted_key, eval_dataset = data.get_dataset(params.validation,
                                                    "infer", params)
        references = load_references(params.references)
    else:
        sorted_key = None
        eval_dataset = None
        references = None

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if args.checkpoint is not None:
        # Load pre-trained models
        state = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state["model"])
        step = params.initial_step
        epoch = 0
        broadcast(model)
    elif checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        features, labels = inputs
        loss = model(features, labels)
        return loss

    counter = 0

    while True:
        for features in dataset:
            if counter % params.update_cycle == 0:
                step += 1
                utils.set_global_step(step)

            counter += 1
            t = time.time()
            features = data.lookup(features, "train", params)
            loss = train_fn(features)
            gradients = optimizer.compute_gradients(loss,
                                                    list(model.parameters()))
            grads_and_vars = exclude_variables(
                trainable_flags,
                zip(gradients, list(model.named_parameters())))
            optimizer.apply_gradients(grads_and_vars)

            t = time.time() - t
            summary.scalar("loss", loss, step, write_every_n_steps=1)
            summary.scalar("global_step/sec", t, step)

            print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" %
                  (epoch + 1, step, float(loss), t))

            if counter % params.update_cycle == 0:
                if step >= params.train_steps:
                    utils.evaluate(model, sorted_key, eval_dataset,
                                   params.output, references, params)
                    save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

                if step % params.eval_steps == 0:
                    utils.evaluate(model, sorted_key, eval_dataset,
                                   params.output, references, params)

                if step % params.save_checkpoint_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, params)

        epoch += 1
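
# NOTE: Both trainers resume from a checkpoint dict containing the keys
# "step", "epoch", "model" and (optionally) "optimizer". save_checkpoint
# itself is not shown here; the function below is a hypothetical
# stand-in whose key names mirror what the resume branch reads back.
# The filename scheme is an assumption, not necessarily what the
# toolkit uses.

import os
import torch


def save_checkpoint_sketch(step, epoch, model, optimizer, output_dir):
    state = {
        "step": step,
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(state, os.path.join(output_dir, "model-%d.pt" % step))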
def main(args):
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params())
    params = import_params(args.checkpoint, args.model, params)
    params = override_params(params, args)

    dist.init_process_group("nccl", init_method=args.url,
                            rank=args.local_rank,
                            world_size=len(params.device_list))
    torch.cuda.set_device(params.device_list[args.local_rank])
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    def score_fn(inputs, _model, level="sentence"):
        _features, _labels = inputs
        _score = _model(_features, _labels, mode="eval", level=level)
        return _score

    # Create model
    with torch.no_grad():
        model = model_cls(params).cuda()

        if args.half:
            model = model.half()

        if not params.monte_carlo:
            model.eval()

        model.load_state_dict(
            torch.load(utils.latest_checkpoint(args.checkpoint),
                       map_location="cpu")["model"])

        dataset = data.get_dataset(args.input, "eval", params)
        data_iter = iter(dataset)
        counter = 0
        pad_max = 1024

        # Buffers for synchronization
        size = torch.zeros([dist.get_world_size()]).long()

        if params.level == "sentence":
            t_list = [
                torch.empty([params.decode_batch_size]).float()
                for _ in range(dist.get_world_size())
            ]
        else:
            t_list = [
                torch.empty([params.decode_batch_size, pad_max]).float()
                for _ in range(dist.get_world_size())
            ]

        if dist.get_rank() == 0:
            fd = open(args.output, "w")
        else:
            fd = None

        while True:
            try:
                features = next(data_iter)
                features = data.lookup(features, "eval", params)
                batch_size = features[0]["source"].shape[0]
            except:
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float(),
                    "target": torch.ones([1, 1]).long(),
                    "target_mask": torch.ones([1, 1]).float()
                }, torch.ones([1, 1]).long()
                batch_size = 0

            t = time.time()
            counter += 1

            scores = score_fn(features, model, params.level)

            # Padding
            if params.level == "sentence":
                pad_batch = params.decode_batch_size - scores.shape[0]
                scores = torch.nn.functional.pad(scores, [0, pad_batch])
            else:
                pad_batch = params.decode_batch_size - scores.shape[0]
                pad_length = pad_max - scores.shape[1]
                scores = torch.nn.functional.pad(
                    scores, (0, pad_length, 0, pad_batch), value=-1)

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, scores.float())

            if size.sum() == 0:
                break

            if dist.get_rank() != 0:
                continue

            for i in range(params.decode_batch_size):
                for j in range(dist.get_world_size()):
                    n = size[j]
                    score = t_list[j][i]

                    if i >= n:
                        continue

                    if params.level == "sentence":
                        fd.write("{:.4f}\n".format(score))
                    else:
                        s_list = score.tolist()
                        for s in s_list:
                            if s >= 0:
                                fd.write("{:.8f} ".format(s))
                            else:
                                fd.write("\n")
                                break

            t = time.time() - t
            logging.info("Finished batch: %d (%.3f sec)" % (counter, t))

        if dist.get_rank() == 0:
            fd.close()
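
# NOTE: For word-level scoring the per-token score rows are padded to
# pad_max with value=-1 before dist.all_gather, and the write loop above
# treats the first negative entry as the end of a sentence. The demo
# below illustrates that encode/decode convention; the score values are
# made up.

import torch


def _word_level_padding_demo(pad_max=8):
    scores = torch.tensor([0.52, 1.37, 0.08])   # one sentence, 3 token scores
    padded = torch.nn.functional.pad(
        scores, (0, pad_max - scores.shape[0]), value=-1)

    # Writing side: keep everything before the first negative (padding)
    # entry, exactly as the score-writing loop does.
    recovered = []
    for s in padded.tolist():
        if s < 0:
            break
        recovered.append(s)
    return recovered   # ~[0.52, 1.37, 0.08] up to float32 rounding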