Example #1
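Distributed ensemble translation: loads one or more checkpointed models, decodes with beam search (or argmax decoding for teacher-forced evaluation), pads and all-gathers the results across ranks, and streams translations to the output file on rank 0, with extra cache handling for cached-transformer models.

All five examples assume a common preamble roughly like the sketch below. The lower-case modules (models, data, utils, optimizers, summary) and helpers such as default_params, merge_params, import_params, override_params, convert_to_string, save_checkpoint and broadcast are project-local, so the commented import paths are assumptions rather than part of the original source.

import logging
import time

import torch
import torch.distributed as dist
from tqdm import tqdm

# Project-local modules referenced by the examples (exact paths unknown):
# import models
# import data
# import utils
# import optimizers
# import summary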
def main(args):
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_params() for _ in range(len(model_cls_list))]
    params_list = [
        merge_params(params, model_cls.default_params())
        for params, model_cls in zip(params_list, model_cls_list)]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))]
    params_list = [
        override_params(params_list[i], args)
        for i in range(len(model_cls_list))]

    params = params_list[0]
    dist.init_process_group("nccl", init_method=args.url,
                            rank=args.local_rank,
                            world_size=len(params.device_list))
    torch.cuda.set_device(params.device_list[args.local_rank])
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    # Create model
    with torch.no_grad():
        model_list = []

        if len(args.input) == 1:
            mode = "infer"
            if params.from_torchtext:
                dataset = data.get_dataset_torchtext(args.input[0], mode, params)
            else:
                dataset = data.get_dataset(args.input[0], mode, params)
        else:
            # Teacher-forcing
            mode = "eval"
            if params.from_torchtext:
                dataset = data.get_dataset_torchtext(args.input, mode, params)
            else:
                dataset = data.get_dataset(args.input, mode, params)

        iterator = iter(dataset)
        idx = 0
        counter = 0
        pad_max = 1024
        top_beams = params.top_beams
        decode_batch_size = params.decode_batch_size

        # Count batches and find the longest source sequence in the dataset
        total_len = 0
        max_length = 0
        for sample in iterator:
            total_len += 1
            length = sample['source'].shape[1]
            if length > max_length:
                max_length = length
        iterator = iter(dataset)

        for param in params_list:
            if hasattr(param, "max_length"):
                param.max_length = min(param.max_length, max_length)
            else:
                param.max_length = max_length

        for i in range(len(args.models)):
            model = model_cls_list[i](params_list[i]).cuda()

            if args.half:
                model = model.half()

            model.eval()
            model.load_state_dict(
                torch.load(utils.latest_checkpoint(args.checkpoints[i]),
                           map_location="cpu")["model"])

            model_list.append(model)

        # Buffers for synchronization
        size = torch.zeros([dist.get_world_size()]).long()
        t_list = [torch.empty([decode_batch_size, top_beams, pad_max]).long()
                  for _ in range(dist.get_world_size())]

        if dist.get_rank() == 0:
            fd = open(args.output, "wb")
            pbar = tqdm(total=total_len)
            pbar.set_description("Translating to {}".format(args.output))
        else:
            fd = None

        states = [None for _ in model_list]
        # Per-model feature history, only used by cached-transformer models
        last_features = [None for _ in model_list]
        for model in model_list:
            if model.name == "cachedtransformer":
                model.encoder.cache.set_batch_size(params.decode_batch_size)
                model.decoder.cache.set_batch_size(params.decode_batch_size)

        while True:
            try:
                features = next(iterator)
                features = data.lookup(features, mode, params,
                                       from_torchtext=params.from_torchtext)

                if mode == "eval":
                    features = features[0]

                batch_size = features["source"].shape[0]
            except StopIteration:
                # Dataset exhausted: feed a dummy batch so collective ops stay in sync
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }

                if mode == "eval":
                    features["target"] = torch.ones([1, 1]).long()
                    features["target_mask"] = torch.ones([1, 1]).float()

                batch_size = 0
            finally:
                for im, model in enumerate(model_list):
                    if model.name == "cachedtransformer":
                        features = update_cache(model, features, states[im], last_features[im], evaluate=True)
                        last_features[im] = features

            counter += 1

            # Decode
            if mode != "eval":
                seqs, _, states = utils.beam_search(model_list, features, params)
            else:
                seqs, _ = utils.argmax_decoding(model_list, features, params)

            # Padding
            pad_batch = decode_batch_size - seqs.shape[0]
            pad_beams = top_beams - seqs.shape[1]
            pad_length = pad_max - seqs.shape[2]
            seqs = torch.nn.functional.pad(
                seqs, (0, pad_length, 0, pad_beams, 0, pad_batch))

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, seqs)

            if size.sum() == 0:
                break

            if dist.get_rank() != 0:
                continue

            for i in range(decode_batch_size):
                for j in range(dist.get_world_size()):
                    for k in range(top_beams):
                        # Skip padding rows gathered from ranks with smaller batches
                        if i >= size[j]:
                            continue

                        seq = convert_to_string(t_list[j][i][k], params)

                        if top_beams == 1:
                            fd.write(seq)
                            fd.write(b"\n")
                        else:
                            fd.write(str(idx).encode("utf-8"))
                            fd.write(b"\t")
                            fd.write(str(k).encode("utf-8"))
                            fd.write(b"\t")
                            fd.write(seq)
                            fd.write(b"\n")

                    idx = idx + 1

            # Only rank 0 reaches this point; non-zero ranks continue above
            pbar.update(1)

        if dist.get_rank() == 0:
            pbar.close()
            fd.close()
Example #2
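A variant of the translation entry point with a CPU fallback: when args.cpu is set it initializes a single-process gloo group instead of nccl, buffers all decoded outputs in memory, and restores the original input order through sorted_key before writing the results.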
def main(args):
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_params() for _ in range(len(model_cls_list))]
    params_list = [
        merge_params(params, model_cls.default_params())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_params(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    params = params_list[0]

    if args.cpu:
        dist.init_process_group("gloo",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=1)
        torch.set_default_tensor_type(torch.FloatTensor)
    else:
        dist.init_process_group("nccl",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    # Create model
    with torch.no_grad():
        model_list = []

        for i in range(len(args.models)):
            if args.cpu:
                model = model_cls_list[i](params_list[i])
            else:
                model = model_cls_list[i](params_list[i]).cuda()

            if args.half:
                model = model.half()

            model.eval()
            model.load_state_dict(
                torch.load(utils.latest_checkpoint(args.checkpoints[i]),
                           map_location="cpu")["model"])

            model_list.append(model)

        if len(args.input) == 1:
            mode = "infer"
            sorted_key, dataset = data.get_dataset(args.input[0], mode, params)
        else:
            # Teacher-forcing
            mode = "eval"
            dataset = data.get_dataset(args.input, mode, params)
            sorted_key = None

        iterator = iter(dataset)
        counter = 0
        pad_max = 1024
        top_beams = params.top_beams
        decode_batch_size = params.decode_batch_size

        # Buffers for synchronization
        size = torch.zeros([dist.get_world_size()]).long()
        t_list = [
            torch.empty([decode_batch_size, top_beams, pad_max]).long()
            for _ in range(dist.get_world_size())
        ]

        all_outputs = []

        while True:
            try:
                features = next(iterator)
                features = data.lookup(features, mode, params, to_cpu=args.cpu)

                if mode == "eval":
                    features = features[0]

                batch_size = features["source"].shape[0]
            except StopIteration:
                # Dataset exhausted: feed a dummy batch so collective ops stay in sync
                features = {
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float()
                }

                if mode == "eval":
                    features["target"] = torch.ones([1, 1]).long()
                    features["target_mask"] = torch.ones([1, 1]).float()

                batch_size = 0

            t = time.time()
            counter += 1

            # Decode
            if mode != "eval":
                seqs, _ = utils.beam_search(model_list, features, params)
            else:
                seqs, _ = utils.argmax_decoding(model_list, features, params)

            # Padding
            pad_batch = decode_batch_size - seqs.shape[0]
            pad_beams = top_beams - seqs.shape[1]
            pad_length = pad_max - seqs.shape[2]
            seqs = torch.nn.functional.pad(
                seqs, (0, pad_length, 0, pad_beams, 0, pad_batch))

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))

            if args.cpu:
                t_list[dist.get_rank()].copy_(seqs)
            else:
                dist.all_reduce(size)
                dist.all_gather(t_list, seqs)

            if size.sum() == 0:
                break

            if dist.get_rank() != 0:
                continue

            for i in range(decode_batch_size):
                for j in range(dist.get_world_size()):
                    # Skip padding rows gathered from ranks with smaller batches
                    if i >= size[j]:
                        continue

                    beam_seqs = []
                    for k in range(top_beams):
                        seq = convert_to_string(t_list[j][i][k], params)
                        beam_seqs.append(seq)

                    all_outputs.append(beam_seqs)

            t = time.time() - t
            print("Finished batch: %d (%.3f sec)" % (counter, t))

        if dist.get_rank() == 0:
            restored_outputs = []
            if sorted_key is not None:
                for idx in range(len(all_outputs)):
                    restored_outputs.append(all_outputs[sorted_key[idx]])
            else:
                restored_outputs = all_outputs

            with open(args.output, "wb") as fd:
                if top_beams == 1:
                    for seqs in restored_outputs:
                        fd.write(seqs[0] + b"\n")
                else:
                    for idx, seqs in enumerate(restored_outputs):
                        for k, seq in enumerate(seqs):
                            fd.write(b"%d\t%d\t" % (idx, k))
                            fd.write(seq + b"\n")
Example #3
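Training entry point: merges default, saved, and command-line parameters, builds an Adam, Adadelta, or SGD optimizer, resumes from the latest checkpoint when available, and runs an update-cycle loop with periodic logging, evaluation, and checkpointing (plus cache updates for cached-transformer models).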
def main(args):
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        dist.init_process_group("nccl",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()

    if args.half:
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)

    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)

    if params.optimizer.lower() == "adam":
        optimizer = optimizers.AdamOptimizer(learning_rate=schedule,
                                             beta_1=params.adam_beta1,
                                             beta_2=params.adam_beta2,
                                             epsilon=params.adam_epsilon,
                                             clipper=clipper,
                                             summaries=params.save_summary)
    elif params.optimizer.lower() == "adadelta":
        optimizer = optimizers.AdadeltaOptimizer(
            learning_rate=schedule,
            rho=params.adadelta_rho,
            epsilon=params.adadelta_epsilon,
            clipper=clipper,
            summaries=params.save_summary)
    elif params.optimizer.lower() == "sgd":
        optimizer = optimizers.SGDOptimizer(learning_rate=schedule,
                                            clipper=clipper,
                                            summaries=params.save_summary)
    else:
        raise ValueError("Unknown optimizer %s" % params.optimizer)

    if args.half:
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    if dist.get_rank() == 0:
        print_variables(model)

    if params.from_torchtext:
        dataset = data.get_dataset_torchtext(params.input, "train", params)
    else:
        dataset = data.get_dataset(params.input, "train", params)

    if params.validation:
        if params.from_torchtext:
            eval_dataset = data.get_dataset_torchtext(params.validation,
                                                      "infer", params)
        else:
            eval_dataset = data.get_dataset(params.validation, "infer", params)
        references = load_references(params.references)
    else:
        eval_dataset = None
        references = None

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if args.checkpoint is not None:
        # Load pre-trained models
        state = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state["model"], strict=False)
        step = params.initial_step
        epoch = 0
        broadcast(model)
    elif checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        features, labels = inputs
        loss, state = model(features, labels)
        return loss, state

    counter = 0
    state = None
    if params.model == "cachedtransformer":
        last_feature = None

    while True:
        start_time = time.time()

        for features in dataset:
            if counter % params.update_cycle == 0:
                step += 1
                utils.set_global_step(step)

            counter += 1
            t = time.time()
            features = data.lookup(features,
                                   "train",
                                   params,
                                   from_torchtext=params.from_torchtext)
            if model.name == "cachedtransformer":
                features = utils.update_cache(model, features, state,
                                              last_feature)
                last_feature = features[0]
            loss, state = train_fn(features)
            gradients = optimizer.compute_gradients(loss,
                                                    list(model.parameters()))
            grads_and_vars = optimizers.exclude_variables(
                params.pattern, zip(gradients, list(model.named_parameters())))
            optimizer.apply_gradients(grads_and_vars)

            t = time.time() - t

            summary.scalar("loss", loss, step, write_every_n_steps=1)
            summary.scalar("global_step/sec", t, step)

            if counter % params.update_cycle == 0:
                if step > 0 and step % args.log_interval == 0:
                    elapsed = time.time() - start_time
                    print('| epoch {:2d} | step {:8d} | lr {:02.2e} | '
                          'ms/step {:4.0f} | loss {:8.4f} '.format(
                              epoch + 1, step,
                              optimizer._optimizer._learning_rate(step),
                              elapsed * 1000 / args.log_interval, loss.item()))
                    start_time = time.time()

                if step >= params.train_steps:
                    utils.evaluate(model, eval_dataset, params.output,
                                   references, params)
                    save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

                if step % params.eval_steps == 0:
                    utils.evaluate(model, eval_dataset, params.output,
                                   references, params)
                    start_time = time.time()

                if step % params.save_checkpoint_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, params)
                    start_time = time.time()

        epoch += 1
Example #4
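A leaner training loop than Example #3: optimizer construction is delegated to get_optimizer, the variables excluded from training are resolved once into trainable_flags via print_variables, and evaluation restores output order through sorted_key.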
def main(args):
    model_cls = models.get_model(args.model)

    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params(args.hparam_set))
    params = import_params(args.output, args.model, params)
    params = override_params(params, args)

    # Initialize distributed utility
    if args.distributed:
        dist.init_process_group("nccl")
        torch.cuda.set_device(args.local_rank)
        torch.set_default_tensor_type(torch.cuda.FloatTensor)
    else:
        dist.init_process_group("nccl",
                                init_method=args.url,
                                rank=args.local_rank,
                                world_size=len(params.device_list))
        torch.cuda.set_device(params.device_list[args.local_rank])
        torch.set_default_tensor_type(torch.cuda.FloatTensor)

    # Export parameters
    if dist.get_rank() == 0:
        export_params(params.output, "params.json", params)
        export_params(params.output, "%s.json" % params.model,
                      collect_params(params, model_cls.default_params()))

    model = model_cls(params).cuda()

    if args.half:
        model = model.half()
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    model.train()

    # Init tensorboard
    summary.init(params.output, params.save_summary)

    schedule = get_learning_rate_schedule(params)
    clipper = get_clipper(params)
    optimizer = get_optimizer(params, schedule, clipper)

    if args.half:
        optimizer = optimizers.LossScalingOptimizer(optimizer)

    optimizer = optimizers.MultiStepOptimizer(optimizer, params.update_cycle)

    trainable_flags = print_variables(model, params.pattern,
                                      dist.get_rank() == 0)

    dataset = data.get_dataset(params.input, "train", params)

    if params.validation:
        sorted_key, eval_dataset = data.get_dataset(params.validation, "infer",
                                                    params)
        references = load_references(params.references)
    else:
        sorted_key = None
        eval_dataset = None
        references = None

    # Load checkpoint
    checkpoint = utils.latest_checkpoint(params.output)

    if args.checkpoint is not None:
        # Load pre-trained models
        state = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state["model"])
        step = params.initial_step
        epoch = 0
        broadcast(model)
    elif checkpoint is not None:
        state = torch.load(checkpoint, map_location="cpu")
        step = state["step"]
        epoch = state["epoch"]
        model.load_state_dict(state["model"])

        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
    else:
        step = 0
        epoch = 0
        broadcast(model)

    def train_fn(inputs):
        features, labels = inputs
        loss = model(features, labels)
        return loss

    counter = 0

    while True:
        for features in dataset:
            if counter % params.update_cycle == 0:
                step += 1
                utils.set_global_step(step)

            counter += 1
            t = time.time()
            features = data.lookup(features, "train", params)
            loss = train_fn(features)
            gradients = optimizer.compute_gradients(loss,
                                                    list(model.parameters()))
            grads_and_vars = exclude_variables(
                trainable_flags, zip(gradients,
                                     list(model.named_parameters())))
            optimizer.apply_gradients(grads_and_vars)

            t = time.time() - t

            summary.scalar("loss", loss, step, write_every_n_steps=1)
            summary.scalar("global_step/sec", t, step)

            print("epoch = %d, step = %d, loss = %.3f (%.3f sec)" %
                  (epoch + 1, step, float(loss), t))

            if counter % params.update_cycle == 0:
                if step >= params.train_steps:
                    utils.evaluate(model, sorted_key, eval_dataset,
                                   params.output, references, params)
                    save_checkpoint(step, epoch, model, optimizer, params)

                    if dist.get_rank() == 0:
                        summary.close()

                    return

                if step % params.eval_steps == 0:
                    utils.evaluate(model, sorted_key, eval_dataset,
                                   params.output, references, params)

                if step % params.save_checkpoint_steps == 0:
                    save_checkpoint(step, epoch, model, optimizer, params)

        epoch += 1
Example #5
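Scoring entry point: loads a single checkpoint, scores (source, target) pairs at sentence or word level (the model stays in train mode when params.monte_carlo is set, presumably for Monte Carlo dropout), and writes the scores on rank 0 after synchronizing padded score tensors across ranks.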
def main(args):
    model_cls = models.get_model(args.model)
    # Import and override parameters
    # Priorities (low -> high):
    # default -> saved -> command
    params = default_params()
    params = merge_params(params, model_cls.default_params())
    params = import_params(args.checkpoint, args.model, params)
    params = override_params(params, args)

    dist.init_process_group("nccl",
                            init_method=args.url,
                            rank=args.local_rank,
                            world_size=len(params.device_list))
    torch.cuda.set_device(params.device_list[args.local_rank])
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    if args.half:
        torch.set_default_dtype(torch.half)
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

    def score_fn(inputs, _model, level="sentence"):
        _features, _labels = inputs
        _score = _model(_features, _labels, mode="eval", level=level)
        return _score

    # Create model
    with torch.no_grad():
        model = model_cls(params).cuda()

        if args.half:
            model = model.half()

        if not params.monte_carlo:
            model.eval()

        model.load_state_dict(
            torch.load(utils.latest_checkpoint(args.checkpoint),
                       map_location="cpu")["model"])
        dataset = data.get_dataset(args.input, "eval", params)
        data_iter = iter(dataset)
        counter = 0
        pad_max = 1024

        # Buffers for synchronization
        size = torch.zeros([dist.get_world_size()]).long()
        if params.level == "sentence":
            t_list = [
                torch.empty([params.decode_batch_size]).float()
                for _ in range(dist.get_world_size())
            ]
        else:
            t_list = [
                torch.empty([params.decode_batch_size, pad_max]).float()
                for _ in range(dist.get_world_size())
            ]

        if dist.get_rank() == 0:
            fd = open(args.output, "w")
        else:
            fd = None

        while True:
            try:
                features = next(data_iter)
                features = data.lookup(features, "eval", params)
                batch_size = features[0]["source"].shape[0]
            except StopIteration:
                # Dataset exhausted: feed a dummy (features, labels) pair so
                # collective ops stay in sync across ranks
                features = ({
                    "source": torch.ones([1, 1]).long(),
                    "source_mask": torch.ones([1, 1]).float(),
                    "target": torch.ones([1, 1]).long(),
                    "target_mask": torch.ones([1, 1]).float()
                }, torch.ones([1, 1]).long())
                batch_size = 0

            t = time.time()
            counter += 1

            scores = score_fn(features, model, params.level)

            # Padding
            if params.level == "sentence":
                pad_batch = params.decode_batch_size - scores.shape[0]
                scores = torch.nn.functional.pad(scores, [0, pad_batch])
            else:
                pad_batch = params.decode_batch_size - scores.shape[0]
                pad_length = pad_max - scores.shape[1]
                scores = torch.nn.functional.pad(scores,
                                                 (0, pad_length, 0, pad_batch),
                                                 value=-1)

            # Synchronization
            size.zero_()
            size[dist.get_rank()].copy_(torch.tensor(batch_size))
            dist.all_reduce(size)
            dist.all_gather(t_list, scores.float())

            if size.sum() == 0:
                break

            if dist.get_rank() != 0:
                continue

            for i in range(params.decode_batch_size):
                for j in range(dist.get_world_size()):
                    # Skip padding rows from ranks with smaller batches
                    if i >= size[j]:
                        continue

                    score = t_list[j][i]

                    if params.level == "sentence":
                        fd.write("{:.4f}\n".format(score))
                    else:
                        s_list = score.tolist()
                        for s in s_list:
                            if s >= 0:
                                fd.write("{:.8f} ".format(s))
                            else:
                                fd.write("\n")
                                break

            t = time.time() - t
            logging.info("Finished batch: %d (%.3f sec)", counter, t)

        if dist.get_rank() == 0:
            fd.close()