Example #1
def main():
    # Loads vocab.
    vocab = make_vocab("data/ptb.train.txt")
    print("#vocab:", len(vocab))  # maybe 10000
    eos_id = vocab["<s>"]

    # Loads all corpora.
    train_corpus = load_corpus("data/ptb.train.txt", vocab)
    valid_corpus = load_corpus("data/ptb.valid.txt", vocab)
    num_train_sents = len(train_corpus)
    num_valid_sents = len(valid_corpus)
    num_train_labels = count_labels(train_corpus)
    num_valid_labels = count_labels(valid_corpus)
    print("train:", num_train_sents, "sentences,", num_train_labels, "labels")
    print("valid:", num_valid_sents, "sentences,", num_valid_labels, "labels")

    # Device and computation graph.
    dev = D.CUDA(0)
    Device.set_default(dev)
    g = Graph()
    Graph.set_default(g)

    # Our LM.
    lm = RNNLM(len(vocab), eos_id)

    # Optimizer.
    optimizer = O.SGD(1)
    #optimizer.set_weight_decay(1e-6)
    optimizer.set_gradient_clipping(5)
    optimizer.add(lm)

    # Sentence IDs.
    train_ids = list(range(num_train_sents))
    valid_ids = list(range(num_valid_sents))

    best_valid_ppl = 1e10

    # Train/valid loop.
    for epoch in range(MAX_EPOCH):
        print("epoch", epoch + 1, "/", MAX_EPOCH, ":")
        # Shuffles train sentence IDs.
        random.shuffle(train_ids)

        # Training.
        train_loss = 0
        for ofs in range(0, num_train_sents, BATCH_SIZE):
            batch_ids = train_ids[ofs:min(ofs + BATCH_SIZE, num_train_sents)]
            batch = make_batch(train_corpus, batch_ids, eos_id)

            g.clear()

            outputs = lm.forward(batch, True)
            loss = lm.loss(outputs, batch)
            train_loss += loss.to_float() * len(batch_ids)

            optimizer.reset_gradients()
            loss.backward()
            optimizer.update()

            print("%d" % ofs, end="\r")
            sys.stdout.flush()

        train_ppl = math.exp(train_loss / num_train_labels)
        print("  train ppl =", train_ppl)

        # Validation.
        valid_loss = 0
        for ofs in range(0, num_valid_sents, BATCH_SIZE):
            batch_ids = valid_ids[ofs:min(ofs + BATCH_SIZE, num_valid_sents)]
            batch = make_batch(valid_corpus, batch_ids, eos_id)

            g.clear()

            outputs = lm.forward(batch, False)
            loss = lm.loss(outputs, batch)
            valid_loss += loss.to_float() * len(batch_ids)
            print("%d" % ofs, end="\r")
            sys.stdout.flush()

        valid_ppl = math.exp(valid_loss / num_valid_labels)
        print("  valid ppl =", valid_ppl)

        if valid_ppl < best_valid_ppl:
            best_valid_ppl = valid_ppl
            print("  BEST")
        else:
            old_lr = optimizer.get_learning_rate_scaling()
            new_lr = 0.5 * old_lr
            optimizer.set_learning_rate_scaling(new_lr)
            print("  learning rate scaled:", old_lr, "->", new_lr)
Example #2
def train(args):
    global TRAIN_BATCH_SIZE, LEARNING_RATE

    TRAIN_BATCH_SIZE = args.batch if args.batch else TRAIN_BATCH_SIZE
    print("train batch size", TRAIN_BATCH_SIZE)
    LEARNING_RATE = args.lr if args.lr else LEARNING_RATE
    print("learning_rate", LEARNING_RATE)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print(device, n_gpu)

    model, tokenizer = load_model_and_tokenizer(lm=args.lm,
                                                model_dir=args.model_dir)

    all_exs = []
    all_data = []

    if not args.num_examples:
        args.num_examples = [None] * len(args.data)
    else:
        args.num_examples += [None] * (len(args.data) - len(args.num_examples))
    print(args.data, args.num_examples)

    for data_source, num_exs in zip(args.data, args.num_examples):
        exs, data = get_train_data(data_source, tokenizer, lm=args.lm,
                                   num_examples=num_exs, mask=args.mask,
                                   distant_source=args.distant_source)
        all_exs.append(exs)
        all_data.append(data)

    '''
    if args.unsup:
      if args.unsup.endswith(".pkl"):
        inputs = pickle.load(open(args.unsup, 'rb'))
        u_exs = inputs['exs']
        u_data = inputs['old_data']
        u_new_data = inputs['new_data']
      else:
        assert args.unsup in set(["matres", "udst"])
        u_exs, u_data = get_train_data(args.unsup, lm=args.lm, num_examples=args.unsup_num_examples, mask=args.mask)
      print(len(u_exs), "unsup examples loaded")
      UNSUP_BATCH_SIZE = args.unsup_batch if args.unsup_batch else int(TRAIN_BATCH_SIZE/2)
      uda_dataset = UdaDataset(u_exs, UNSUP_BATCH_SIZE)
    '''

    OUTPUT_DIR = args.output_dir if args.output_dir else "models/scratch/"
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)

    if len(all_exs) == 0:
        print("no dataset specified")
        return
    elif len(all_exs) == 1:
        print("one dataset specified")
        exs = all_exs[0]
        data = all_data[0]
    else:
        print("using multiple data sources")
        inputs = []
        for i in range(len(all_data[0].tensors)):
            inputs.append(torch.cat([d.tensors[i] for d in all_data]))

        exs = list(chain(*all_exs))
        data = TensorDataset(*inputs)


    data_sampler = RandomSampler(data)
    dataloader = DataLoader(data, sampler=data_sampler,
                            batch_size=TRAIN_BATCH_SIZE)

    print(len(data), len(exs), "examples loaded")

    num_train_optimization_steps = int(
        len(data) / TRAIN_BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS) * NUM_TRAIN_EPOCHS
    print(num_train_optimization_steps, "optimization steps")
    num_warmup_steps = WARMUP_PROPORTION * num_train_optimization_steps

    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Hack: drop the unused pooler parameters, whose None gradients break apex.
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
        nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(
        nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
              lr=LEARNING_RATE,
              correct_bias=False)
    scheduler = get_linear_schedule_with_warmup(optimizer,
                            num_warmup_steps=num_warmup_steps,
                            num_training_steps=num_train_optimization_steps)  # PyTorch scheduler

    if args.serialize:
        inputs = {"exs": exs, "data": data}
        pickle.dump(inputs, open(OUTPUT_DIR+"inputs.pkl", 'wb'))

    logfile = open(OUTPUT_DIR + "/log.txt", "w+")
    print(args, file=logfile)
    print(len(data), len(exs), "examples loaded", file=logfile)
    count_labels(exs, file=logfile)
    count_labels(exs)
    print("learning_rate", LEARNING_RATE, file=logfile)

    global_step = 0
    num_epochs = args.epochs if args.epochs is not None else int(NUM_TRAIN_EPOCHS)
    if num_epochs == 0:
        exit()
    model.train()
    exs_cpy = exs

    for ep in trange(num_epochs, desc="Epoch"):
        last_loss_kldiv = 0
        for step, batch in enumerate(tqdm(dataloader,
                          desc="Iteration " + str(ep),
                          disable=args.disable_tqdm)):
            bbatch = tuple(t.to(device) for t in batch)
            loss, _, _ = model(*bbatch)

            '''
            if args.unsup:
                loss_kldiv = unsup_loss(model)
                if loss_kldiv:
                    last_loss_kldiv = loss_kldiv.item()
                    loss += loss_kldiv
            '''

            loss.backward()
            if step % 100 == 0:
                print("Loss: %.3f at step %d" % (loss.item(), step), file=logfile)
                # if args.unsup and last_loss_kldiv:
                #  print("Unsup Loss: %.3f at step %d" %(last_loss_kldiv, step), file=logfile)
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

        # Save the trained model, configuration and tokenizer once per epoch.
        model_output_dir = OUTPUT_DIR + "/output_" + str(ep) + "/"
        if not os.path.exists(model_output_dir):
            os.makedirs(model_output_dir)
        model.save_pretrained(model_output_dir)
        tokenizer.save_pretrained(model_output_dir)
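For reference, here is a hypothetical command-line entry point for Example #2, reconstructed from the attributes that train(args) reads (batch, lr, lm, model_dir, data, num_examples, mask, distant_source, output_dir, serialize, epochs, disable_tqdm). The original script's argument names, types, and defaults may differ.

import argparse

def parse_args():
    # Assumed CLI; only the attribute names are taken from the code above.
    parser = argparse.ArgumentParser(description="Fine-tune a pretrained LM (sketch).")
    parser.add_argument("--lm", default="roberta-base", help="pretrained LM identifier")
    parser.add_argument("--model_dir", default=None, help="optional local checkpoint directory")
    parser.add_argument("--data", nargs="+", default=[], help="one or more training data sources")
    parser.add_argument("--num_examples", nargs="+", type=int, default=None,
                        help="optional per-source example caps, aligned with --data")
    parser.add_argument("--batch", type=int, default=None, help="override TRAIN_BATCH_SIZE")
    parser.add_argument("--lr", type=float, default=None, help="override LEARNING_RATE")
    parser.add_argument("--epochs", type=int, default=None, help="override NUM_TRAIN_EPOCHS")
    parser.add_argument("--mask", action="store_true")
    parser.add_argument("--distant_source", default=None)
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--serialize", action="store_true")
    parser.add_argument("--disable_tqdm", action="store_true")
    return parser.parse_args()

if __name__ == "__main__":
    train(parse_args())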
Example #3
def main():
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    size = (192, 160)
    origin_size = (240, 240)
    filter_NET = 500
    resize_method = 'resize'
    mixed_precision = False
    summary_path = 'summaries/20200829_020604'
    std_path = None
    pre_scaling = True
    test_tfr_path = r'E:\Shared folder\BraTS\tfr\3d\test\test_2020_pre_scaling(2020_train_landmarks)\*'
    transform_mean_std_path = None
    base_path = r'E:\Shared folder\BraTS\2020\MICCAI_BraTS2020_TrainingData\BraTS20_Training_001\BraTS20_Training_001_seg.nii.gz'
    # test_tfr_path = r'E:\Shared folder\BraTS\tfr\test\2019\*.tfrecord'
    # test_tfr_path = r'E:\Shared folder\BraTS\tfr\non_zero_non_scaling_2019\valid_concat\*_39.*'

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
                print(f'Setting {gpu} successfully')
        except RuntimeError as e:
            print(e)

    if mixed_precision:
        policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
        tf.keras.mixed_precision.experimental.set_policy(policy)
        print('Compute dtype: %s' % policy.compute_dtype)
        print('Variable dtype: %s' % policy.variable_dtype)

    # cast tfr to dataset api
    with tf.name_scope('loader'):
        # preprocessing=False: skip the training-time preprocessing pipeline.
        # test_loader = load_data.tf_dataset(test_tfr_path,
        #                                    std_path,
        #                                    preprocessing=False,
        #                                    batch=batch,
        #                                    shuffle=False,
        #                                    repeat=1)
        test_loader = load_test_data.tf_dataset(
            test_tfr_path,
            None,
            transform_mean_std_path=transform_mean_std_path,
            scaling_path=std_path,
            pre_scaling=pre_scaling,
            preprocessing=False,
            batch=1,
            shuffle=False,
            repeat=1,
            size=size,
            resize_method=resize_method)

    net = model.UNet()
    # Graph-compiled inference step (tf.function traces this into a TF graph).
    @tf.function
    def test_step(x):
        pred = net(x, training=False)
        s_pred = tf.squeeze(pred, axis=0)
        resize_pred = tf.image.resize(s_pred, origin_size)
        result = metrics.logits_2_label(resize_pred, testing=True)
        return result

    # data and labels generator
    test_generator = test_loader.generator()

    step = tf.Variable(0, trainable=False)

    # saver
    ckpt = tf.train.Checkpoint(net=net, step=step)
    ckpt_manager = tf.train.CheckpointManager(ckpt, summary_path, 5)
    latest_path = ckpt_manager.latest_checkpoint
    if latest_path:
        print('-----------------------Restoring: {}-----------------------'.
              format(latest_path))
        ckpt.restore(latest_path)

    n_step = step.numpy()
    print('--------------------Evaluation--------------------\n')
    now_time = time.strftime('Step_{}_%Y%m%d_%H%M'.format(n_step),
                             time.localtime())
    nii_path = os.path.join(summary_path, now_time + '_results')
    if not os.path.exists(nii_path):
        os.makedirs(nii_path)

    results_string = ''
    for i, (x, name) in tqdm(enumerate(test_generator)):
        result = test_step(x)
        pad_result = utils.fixed_depth_crop(result, False)
        filter_data, string = utils.count_labels(pad_result, filter_NET)
        results_string += '{}: {}\n'.format(name[0], string)
        print('{}: {}'.format(name[0], string))
        utils.save_itk(
            filter_data,
            nii_path + '/{}.nii.gz'.format(name[0].numpy().decode()),
            base_path)
    print('--------------------------------------------------')
    with open(os.path.join(nii_path, 'results.txt'), 'w') as f:
        f.write(results_string)
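The utils.count_labels call above returns both a filtered label volume and a summary string. A purely illustrative sketch with the same call shape is shown below; the BraTS-style label IDs (1/2/4) and the "relabel a tiny class" heuristic are assumptions, not the project's actual utility.

import numpy as np

def count_labels(pred, threshold, small_label=4, fallback_label=1):
    # pred: integer label volume; counts voxels per class and relabels
    # small_label if it has fewer than `threshold` voxels.
    pred = np.asarray(pred)
    labels, counts = np.unique(pred, return_counts=True)
    summary = ", ".join(
        "label {}: {}".format(int(l), int(c)) for l, c in zip(labels, counts))
    voxels = int(counts[labels == small_label][0]) if small_label in labels else 0
    if 0 < voxels < threshold:
        # Too few voxels to trust this class: fold it into another label.
        pred = np.where(pred == small_label, fallback_label, pred)
        summary += " (label {} below {} voxels, relabeled)".format(small_label, threshold)
    return pred, summary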
Example #4
def train(encdec, optimizer, prefix, best_valid_ppl):
    # Registers all parameters to the optimizer.
    optimizer.add_model(encdec)

    # Loads vocab.
    src_vocab = make_vocab(SRC_TRAIN_FILE, SRC_VOCAB_SIZE)
    trg_vocab = make_vocab(TRG_TRAIN_FILE, TRG_VOCAB_SIZE)
    inv_trg_vocab = make_inv_vocab(trg_vocab)
    print("#src_vocab:", len(src_vocab))
    print("#trg_vocab:", len(trg_vocab))

    # Loads all corpora.
    train_src_corpus = load_corpus(SRC_TRAIN_FILE, src_vocab)
    train_trg_corpus = load_corpus(TRG_TRAIN_FILE, trg_vocab)
    valid_src_corpus = load_corpus(SRC_VALID_FILE, src_vocab)
    valid_trg_corpus = load_corpus(TRG_VALID_FILE, trg_vocab)
    test_src_corpus = load_corpus(SRC_TEST_FILE, src_vocab)
    test_ref_corpus = load_corpus_ref(REF_TEST_FILE, trg_vocab)
    num_train_sents = len(train_trg_corpus)
    num_valid_sents = len(valid_trg_corpus)
    num_test_sents = len(test_ref_corpus)
    num_train_labels = count_labels(train_trg_corpus)
    num_valid_labels = count_labels(valid_trg_corpus)
    print("train:", num_train_sents, "sentences,", num_train_labels, "labels")
    print("valid:", num_valid_sents, "sentences,", num_valid_labels, "labels")

    # Sentence IDs
    train_ids = list(range(num_train_sents))
    valid_ids = list(range(num_valid_sents))

    # Train/valid loop.
    for epoch in range(MAX_EPOCH):
        # Computation graph.
        g = Graph()
        Graph.set_default(g)

        print("epoch %d/%d:" % (epoch + 1, MAX_EPOCH))
        print("  learning rate scale = %.4e" %
              optimizer.get_learning_rate_scaling())

        # Shuffles train sentence IDs.
        random.shuffle(train_ids)

        # Training.
        train_loss = 0.
        for ofs in range(0, num_train_sents, BATCH_SIZE):
            print("%d" % ofs, end="\r")
            sys.stdout.flush()

            batch_ids = train_ids[ofs:min(ofs + BATCH_SIZE, num_train_sents)]
            src_batch = make_batch(train_src_corpus, batch_ids, src_vocab)
            trg_batch = make_batch(train_trg_corpus, batch_ids, trg_vocab)

            g.clear()
            encdec.encode(src_batch, True)
            loss = encdec.loss(trg_batch, True)
            train_loss += loss.to_float() * len(batch_ids)

            optimizer.reset_gradients()
            loss.backward()
            optimizer.update()

        train_ppl = math.exp(train_loss / num_train_labels)
        print("  train PPL = %.4f" % train_ppl)

        # Validation.
        valid_loss = 0.
        for ofs in range(0, num_valid_sents, BATCH_SIZE):
            print("%d" % ofs, end="\r")
            sys.stdout.flush()

            batch_ids = valid_ids[ofs:min(ofs + BATCH_SIZE, num_valid_sents)]
            src_batch = make_batch(valid_src_corpus, batch_ids, src_vocab)
            trg_batch = make_batch(valid_trg_corpus, batch_ids, trg_vocab)

            g.clear()
            encdec.encode(src_batch, False)
            loss = encdec.loss(trg_batch, False)
            valid_loss += loss.to_float() * len(batch_ids)

        valid_ppl = math.exp(valid_loss / num_valid_labels)
        print("  valid PPL = %.4f" % valid_ppl)

        # Calculates test BLEU.
        stats = defaultdict(int)
        for ofs in range(0, num_test_sents, BATCH_SIZE):
            print("%d" % ofs, end="\r")
            sys.stdout.flush()

            src_batch = test_src_corpus[ofs:min(ofs + BATCH_SIZE, num_test_sents)]
            ref_batch = test_ref_corpus[ofs:min(ofs + BATCH_SIZE, num_test_sents)]

            hyp_ids = test_batch(encdec, src_vocab, trg_vocab, src_batch)
            for hyp_line, ref_line in zip(hyp_ids, ref_batch):
                for k, v in get_bleu_stats(ref_line[1:-1], hyp_line).items():
                    stats[k] += v

        bleu = calculate_bleu(stats)
        print("  test BLEU = %.2f" % (100 * bleu))

        # Saves best model/optimizer.
        if valid_ppl < best_valid_ppl:
            best_valid_ppl = valid_ppl
            print("  saving model/optimizer ... ", end="")
            sys.stdout.flush()
            encdec.save(prefix + ".model")
            optimizer.save(prefix + ".optimizer")
            save_ppl(prefix + ".valid_ppl", best_valid_ppl)
            print("done.")
        else:
            # Learning rate decay by 1/sqrt(2)
            new_scale = .7071 * optimizer.get_learning_rate_scaling()
            optimizer.set_learning_rate_scaling(new_scale)
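Example #4 also relies on get_bleu_stats and calculate_bleu, which are not shown. A minimal sketch written as standard corpus-level BLEU-4 with a brevity penalty follows; the helpers bundled with the original trainer may differ in smoothing or n-gram order.

import math
from collections import Counter

def get_bleu_stats(ref, hyp, max_n=4):
    # Per-sentence n-gram match/total counts plus lengths, to be summed over the corpus.
    stats = {"ref_len": len(ref), "hyp_len": len(hyp)}
    for n in range(1, max_n + 1):
        ref_ngrams = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
        hyp_ngrams = Counter(tuple(hyp[i:i + n]) for i in range(len(hyp) - n + 1))
        stats["match%d" % n] = sum((ref_ngrams & hyp_ngrams).values())
        stats["total%d" % n] = max(len(hyp) - n + 1, 0)
    return stats

def calculate_bleu(stats, max_n=4):
    # Geometric mean of modified n-gram precisions times the brevity penalty.
    log_prec = 0.0
    for n in range(1, max_n + 1):
        if stats["match%d" % n] == 0 or stats["total%d" % n] == 0:
            return 0.0
        log_prec += math.log(stats["match%d" % n] / stats["total%d" % n]) / max_n
    log_bp = min(0.0, 1.0 - stats["ref_len"] / stats["hyp_len"])
    return math.exp(log_prec + log_bp)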