Example #1
def __init__(self, conf: XmuModelConf, vpack):
    super().__init__(conf)
    conf: XmuModelConf = self.conf
    self.vpack = vpack
    # =====
    # --
    # init their model
    model_cls = models.get_model("deepatt")
    # --
    params = default_params()
    params = merge_params(params, model_cls.default_params(None))
    # params = import_params(args.output, args.model, params)
    params = override_params(params, conf)
    # --
    self.params = params
    model = model_cls(params).cuda()
    model.load_embedding(params.embedding)
    # --
    self.embedding = data.load_glove_embedding(params.embedding)
    # =====
    # wrap their model
    self.M = ModuleWrapper(model, None)
    self.bio_helper = SeqSchemeHelperStr("BIO")
    # --
    zzz = self.optims  # finally build optim!
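This constructor reuses the parameter-layering idiom that recurs in every example on this page: library defaults first, then the model's own defaults, then caller overrides, with later layers winning. Below is a minimal sketch of that priority order, using plain dicts as hypothetical stand-ins for the real default_params/merge_params/override_params helpers (the real helpers operate on a params object, but the precedence is the same):

# Minimal sketch of the merge/override priority used above.
# These dict-based helpers are stand-ins, not the library's own functions.

def default_params():
    return {"embedding": None, "device": 0}

def merge_params(params, extra):
    merged = dict(params)
    merged.update(extra)          # later sources win
    return merged

def override_params(params, overrides):
    merged = dict(params)
    merged.update({k: v for k, v in overrides.items() if v is not None})
    return merged

params = default_params()
params = merge_params(params, {"embedding": "glove.txt"})   # model defaults
params = override_params(params, {"device": 1})             # caller overrides
assert params == {"embedding": "glove.txt", "device": 1}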
Example #2
def main(args):
    # Load configs
    model_cls = models.get_model(args.model)
    params = default_params()
    params = merge_params(params, model_cls.default_params())
    params = import_params(args.checkpoint, args.model, params)
    params = override_params(params, args)
    torch.cuda.set_device(params.device)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Create model
    with torch.no_grad():
        model = model_cls(params).cuda()

        if args.half:
            model = model.half()
            torch.set_default_tensor_type(torch.cuda.HalfTensor)

        model.eval()
        model.load_state_dict(
            torch.load(utils.best_checkpoint(args.checkpoint),
                       map_location="cpu")["model"])

        # Decoding
        dataset = data.get_dataset(args.input, "infer", params)
        fd = open(args.output, "wb")
        counter = 0

        if params.embedding:
            embedding = data.load_embedding(params.embedding)
        else:
            embedding = None

        for features in dataset:
            t = time.time()
            counter += 1
            features = data.lookup(features, "infer", params, embedding)

            labels = model.argmax_decode(features)
            batch = convert_to_string(features["inputs"], labels, params)
            del features
            del labels

            for seq in batch:
                fd.write(seq)
                fd.write(b"\n")

            t = time.time() - t
            print("Finished batch: %d (%.3f sec)" % (counter, t))

        del dataset
        fd.flush()
        fd.close()
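A hedged sketch of how this decoding entry point might be invoked. The flag names below (--model, --checkpoint, --input, --output, --half) are assumptions inferred from the args attributes the function reads, not confirmed CLI options:

import argparse

# Hypothetical driver for the inference main() above; flag names are
# inferred from the args.* attributes it reads and may differ in the
# real command-line interface.
def parse_args():
    parser = argparse.ArgumentParser(description="decode with a trained model")
    parser.add_argument("--model", required=True, help="model name")
    parser.add_argument("--checkpoint", required=True, help="checkpoint directory")
    parser.add_argument("--input", required=True, help="input file")
    parser.add_argument("--output", required=True, help="output file")
    parser.add_argument("--half", action="store_true", help="use fp16")
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())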
Example #3
def get_model(args):
    model_cls = models.get_model(args.model)

    # Merge parameters (low -> high priority):
    # library defaults -> model/predictor defaults -> values saved in args.dir
    params = default_params()
    params = merge_params(params, model_cls.default_params())
    params = merge_params(params, predictor.default_params())
    params = import_params(args.dir, args.model, params)
    params.decode_batch_size = 1
    # Load vocabularies and word<->index mappings for source and target sides
    src_vocab, src_w2idx, src_idx2w = data.load_vocabulary(params.vocab[0])
    tgt_vocab, tgt_w2idx, tgt_idx2w = data.load_vocabulary(params.vocab[1])

    params.vocabulary = {"source": src_vocab, "target": tgt_vocab}
    params.lookup = {"source": src_w2idx, "target": tgt_w2idx}
    params.mapping = {"source": src_idx2w, "target": tgt_idx2w}

    torch.cuda.set_device(0)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Create model
    model = model_cls(params).cuda()
    return model, params
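For illustration, get_model only reads args.model and args.dir, so it can be exercised with a bare namespace. The model name and directory below are placeholders, not values taken from the source:

from argparse import Namespace

# Hypothetical invocation: "transformer" and the path are placeholders.
args = Namespace(model="transformer", dir="path/to/checkpoint_dir")
model, params = get_model(args)
model.eval()  # get_model leaves train/eval mode to the caller
print(type(model).__name__, params.decode_batch_size)  # decode_batch_size == 1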
Example #4
def main(args):
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        # Launched by an external launcher: init from environment variables
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
    else:
        # Single-node multi-process: rendezvous via an explicit URL
        dist.init_process_group("nccl",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()
    model.load_embedding(params.embedding)

    if args.half:
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)
    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)

    if params.optimizer.lower() == "adam":
        optimizer = optimizers.AdamOptimizer(learning_rate=schedule,
                                             beta_1=params.adam_beta1,
                                             beta_2=params.adam_beta2,
                                             epsilon=params.adam_epsilon,
                                             clipper=clipper)
    elif params.optimizer.lower() == "adadelta":
        optimizer = optimizers.AdadeltaOptimizer(
            learning_rate=schedule,
            rho=params.adadelta_rho,
            epsilon=params.adadelta_epsilon,
            clipper=clipper)
    else:
        raise ValueError("Unknown optimizer %s" % params.optimizer)

    if args.half:
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    if dist.get_rank() == 0:
        print_variables(model)

    dataset = data.get_dataset(params.input, "train", params)

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        features, labels = inputs
        loss = model(features, labels)
        return loss

    counter = 0
    should_save = False

    if params.script:
        thread = ValidationWorker(daemon=True)
        thread.init(params)
        thread.start()
    else:
        thread = None

    def step_fn(features, step):
        t = time.time()
        features = data.lookup(features, "train", params)
        loss = train_fn(features)
        gradients = optimizer.compute_gradients(loss, list(model.parameters()))
        if params.clip_grad_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           params.clip_grad_norm)

        optimizer.apply_gradients(
            zip(gradients, list(model.named_parameters())))

        t = time.time() - t

        summary.scalar("loss", loss, step, write_every_n_steps=1)
        summary.scalar("global_step/sec", t, step)

        print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" %
              (epoch + 1, step, float(loss), t))

    try:
        while True:
            for features in dataset:
                if counter % params.update_cycle == 0:
                    step += 1
                    utils.set_global_step(step)
                    should_save = True

                counter += 1
                step_fn(features, step)

                if step % params.save_checkpoint_steps == 0:
                    if should_save:
                        save_checkpoint(step, epoch, model, optimizer, params)
                        should_save = False

                if step >= params.train_steps:
                    if should_save:
                        save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

            epoch += 1
    finally:
        if thread is not None:
            thread.stop()
            thread.join()
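The interplay of counter, step, and params.update_cycle in the loop above implements gradient accumulation: step advances once per update_cycle mini-batches, and should_save gates checkpointing so each step is saved at most once. A self-contained sketch of just that bookkeeping, with hypothetical interval values:

# Stand-alone sketch of the accumulation bookkeeping in the loop above.
update_cycle = 4            # hypothetical: 4 mini-batches per optimizer step
save_checkpoint_steps = 2   # hypothetical save interval
step, counter, should_save = 0, 0, False
saved_steps = []

for _ in range(16):                     # 16 mini-batches -> 4 optimizer steps
    if counter % update_cycle == 0:
        step += 1
        should_save = True
    counter += 1
    if step % save_checkpoint_steps == 0 and should_save:
        saved_steps.append(step)        # save at most once per step
        should_save = False

assert saved_steps == [2, 4]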