def __init__(self, conf: XmuModelConf, vpack):
    super().__init__(conf)
    conf: XmuModelConf = self.conf
    self.vpack = vpack
    # =====
    # init their model
    model_cls = models.get_model("deepatt")
    # -- build params: framework defaults -> deepatt defaults -> conf overrides
    params = default_params()
    params = merge_params(params, model_cls.default_params(None))
    # params = import_params(args.output, args.model, params)
    params = override_params(params, conf)
    # --
    self.params = params
    model = model_cls(params).cuda()
    model.load_embedding(params.embedding)
    # --
    self.embedding = data.load_glove_embedding(params.embedding)
    # =====
    # wrap their model
    self.M = ModuleWrapper(model, None)
    self.bio_helper = SeqSchemeHelperStr("BIO")
    # --
    zzz = self.optims  # finally build optim!
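
# A minimal, self-contained sketch of the parameter-layering idea used above
# (framework defaults first, then model defaults, then conf overrides; later
# layers win). The dict-based helpers and example values below are illustrative
# stand-ins, not the actual default_params / merge_params / override_params.


def _sketch_merge(base: dict, extra: dict) -> dict:
    # Keys in `extra` override keys in `base`; unset (None) values are skipped.
    merged = dict(base)
    merged.update({k: v for k, v in extra.items() if v is not None})
    return merged


def _sketch_build_params():
    framework_defaults = {"hidden_size": 200, "embedding": None}
    model_defaults = {"hidden_size": 512, "num_layers": 10}   # hypothetical deepatt defaults
    conf_overrides = {"embedding": "glove.txt"}               # hypothetical user config
    params = _sketch_merge(framework_defaults, model_defaults)
    params = _sketch_merge(params, conf_overrides)
    # -> {'hidden_size': 512, 'embedding': 'glove.txt', 'num_layers': 10}
    return params
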
def main(args):
    # Load configs
    model_cls = models.get_model(args.model)
    params = default_params()
    params = merge_params(params, model_cls.default_params())
    params = import_params(args.checkpoint, args.model, params)
    params = override_params(params, args)
    torch.cuda.set_device(params.device)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Create model
    with torch.no_grad():
        model = model_cls(params).cuda()

        if args.half:
            model = model.half()
            torch.set_default_tensor_type(torch.cuda.HalfTensor)

        model.eval()
        model.load_state_dict(
            torch.load(utils.best_checkpoint(args.checkpoint),
                       map_location="cpu")["model"])

        # Decoding
        dataset = data.get_dataset(args.input, "infer", params)
        fd = open(args.output, "wb")
        counter = 0

        if params.embedding:
            embedding = data.load_embedding(params.embedding)
        else:
            embedding = None

        for features in dataset:
            t = time.time()
            counter += 1
            features = data.lookup(features, "infer", params, embedding)
            labels = model.argmax_decode(features)
            batch = convert_to_string(features["inputs"], labels, params)

            del features
            del labels

            for seq in batch:
                fd.write(seq)
                fd.write(b"\n")

            t = time.time() - t
            print("Finished batch: %d (%.3f sec)" % (counter, t))

        del dataset
        fd.flush()
        fd.close()
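
# A minimal sketch of the checkpoint-loading pattern used in the decoder above:
# read the state dict on CPU, restore it into the model, optionally cast to
# fp16, and switch to eval mode before decoding under torch.no_grad(). The
# TinyTagger module and the checkpoint path are hypothetical placeholders, not
# the repo's model class.

import torch
import torch.nn as nn


class TinyTagger(nn.Module):
    def __init__(self, vocab_size=100, hidden=8, num_labels=5):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, num_labels)

    def forward(self, inputs):
        return self.proj(self.emb(inputs))


def load_for_inference(ckpt_path, use_half=False):
    model = TinyTagger()
    # Load on CPU first so the checkpoint's original device does not matter.
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state["model"])
    if use_half:
        model = model.half()
    model.eval()
    return model


# Usage (hypothetical checkpoint path):
#   model = load_for_inference("tagger-ckpt.pt")
#   with torch.no_grad():
#       labels = model(torch.tensor([[3, 7, 9]])).argmax(dim=-1)
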
def get_model(args):
    model_cls = models.get_model(args.model)
    params = default_params()
    params = merge_params(params, model_cls.default_params())
    params = merge_params(params, predictor.default_params())
    params = import_params(args.dir, args.model, params)
    params.decode_batch_size = 1

    src_vocab, src_w2idx, src_idx2w = data.load_vocabulary(params.vocab[0])
    tgt_vocab, tgt_w2idx, tgt_idx2w = data.load_vocabulary(params.vocab[1])

    params.vocabulary = {"source": src_vocab, "target": tgt_vocab}
    params.lookup = {"source": src_w2idx, "target": tgt_w2idx}
    params.mapping = {"source": src_idx2w, "target": tgt_idx2w}

    torch.cuda.set_device(0)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Create model
    model = model_cls(params).cuda()

    return model, params
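
# A minimal sketch of a vocabulary loader returning the triple
# (token list, token->id dict, id->token dict), matching how the three return
# values of load_vocabulary are unpacked above. This is an illustrative
# assumption, not the repo's actual data.load_vocabulary implementation.

def sketch_load_vocabulary(path):
    with open(path, "r", encoding="utf-8") as f:
        tokens = [line.strip() for line in f if line.strip()]
    word2idx = {tok: idx for idx, tok in enumerate(tokens)}
    idx2word = {idx: tok for idx, tok in enumerate(tokens)}
    return tokens, word2idx, idx2word
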
def main(args):
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
    else:
        dist.init_process_group("nccl", init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])

    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()
    model.load_embedding(params.embedding)

    if args.half:
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)

    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)

    if params.optimizer.lower() == "adam":
        optimizer = optimizers.AdamOptimizer(learning_rate=schedule,
                                             beta_1=params.adam_beta1,
                                             beta_2=params.adam_beta2,
                                             epsilon=params.adam_epsilon,
                                             clipper=clipper)
    elif params.optimizer.lower() == "adadelta":
        optimizer = optimizers.AdadeltaOptimizer(
            learning_rate=schedule, rho=params.adadelta_rho,
            epsilon=params.adadelta_epsilon, clipper=clipper)
    else:
        raise ValueError("Unknown optimizer %s" % params.optimizer)

    if args.half:
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    if dist.get_rank() == 0:
        print_variables(model)

    dataset = data.get_dataset(params.input, "train", params)

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        features, labels = inputs
        loss = model(features, labels)
        return loss

    counter = 0
    should_save = False

    if params.script:
        thread = ValidationWorker(daemon=True)
        thread.init(params)
        thread.start()
    else:
        thread = None

    def step_fn(features, step):
        t = time.time()
        features = data.lookup(features, "train", params)
        loss = train_fn(features)
        gradients = optimizer.compute_gradients(loss,
                                                list(model.parameters()))
        if params.clip_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           params.clip_grad_norm)

        optimizer.apply_gradients(
            zip(gradients, list(model.named_parameters())))

        t = time.time() - t
        summary.scalar("loss", loss, step, write_every_n_steps=1)
        summary.scalar("global_step/sec", t, step)

        print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" %
              (epoch + 1, step, float(loss), t))

    try:
        while True:
            for features in dataset:
                # `step` only advances once every `update_cycle` micro-batches
                if counter % params.update_cycle == 0:
                    step += 1
                    utils.set_global_step(step)
                    should_save = True

                counter += 1
                step_fn(features, step)

                if step % params.save_checkpoint_steps == 0:
                    if should_save:
                        save_checkpoint(step, epoch, model, optimizer, params)
                        should_save = False

                if step >= params.train_steps:
                    if should_save:
                        save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

            epoch += 1
    finally:
        if thread is not None:
            thread.stop()
            thread.join()
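
# The training loop above only advances `step` every `update_cycle` micro-batches
# and delegates the actual accumulation to the MultiStepOptimizer wrapper. Below
# is a minimal plain-PyTorch sketch of that gradient-accumulation idea; it is a
# conceptual illustration, not the repo's optimizer implementation.

import torch
import torch.nn as nn


def sketch_accumulated_training(model, batches, update_cycle=4, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    optimizer.zero_grad()

    for counter, (features, labels) in enumerate(batches, start=1):
        # Scale each micro-batch loss so the accumulated gradient matches the
        # average over the full effective batch.
        loss = loss_fn(model(features), labels) / update_cycle
        loss.backward()  # gradients accumulate across micro-batches

        if counter % update_cycle == 0:
            optimizer.step()       # one "real" update per cycle
            optimizer.zero_grad()
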