Code example #1
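    # Snippet from a larger training script; it assumes the following
    # context (not shown in the listing): import os, torch;
    # from tqdm import tqdm; and that PATH, run, network, optimizer,
    # loader, validation_loader and a RunManager instance m are defined
    # earlier in the file.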
    saved_epoch = 0
    lr = run.lr

    l1_loss = torch.nn.L1Loss()

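    # Resume from a saved checkpoint when one exists, restoring model and
    # optimizer state along with the last completed epoch.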
    if os.path.isfile(PATH):
        checkpoint = torch.load(PATH)
        network.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        saved_epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        # loss = checkpoint['loss']
    m.begin_run(run, network, loader, validation_loader)

    for epoch in range(saved_epoch, 150):
        m.begin_epoch()

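        # Hand-coded learning-rate schedule: drop to 1e-4 at epoch 30,
        # then to 2e-5 for every epoch after 70.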
        if epoch == 30:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.0001
        if epoch > 70:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.00002
        for batch in tqdm(loader):

            # Get data
            in_features, t_slf = batch
            in_features = in_features.to(run.device)
            t_slf = t_slf.to(run.device)
            preds, mu, logvar = network(in_features)
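The listing is truncated at this point. A minimal sketch of how a VAE-style step like this typically continues, using the l1_loss defined above for the reconstruction term plus the standard closed-form KL term (the continuation is an assumption, not taken from the source):

            # Hypothetical continuation (assumed, not from the source):
            # L1 reconstruction loss plus the analytic KL divergence of
            # N(mu, sigma^2) from N(0, 1).
            recon_loss = l1_loss(preds, t_slf)
            kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon_loss + kl_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()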
Code example #2
File: trainer.py Project: AIMedLab/TransICD
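# Imports assumed by this snippet (RunManager, evaluate and compute_scores
# come from the project's own modules):
import logging

import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader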
def train(model, train_set, dev_set, test_set, hyper_params, batch_size,
          device):
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)
    m = RunManager()
    optimizer = optim.AdamW(model.parameters(), lr=hyper_params.learning_rate)

    logging.info("Training Started...")
    m.begin_run(hyper_params, model, train_loader)
    for epoch in range(hyper_params.num_epoch):
        m.begin_epoch(epoch + 1)
        model.train()
        for batch in train_loader:
            texts = batch['text']
            lens = batch['length']
            targets = batch['codes']

            texts = texts.to(device)
            targets = targets.to(device)
            outputs, ldam_outputs, _ = model(texts, targets)

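            # Train on the LDAM-adjusted logits when the model provides
            # them; otherwise fall back to the plain outputs.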
            if ldam_outputs is not None:
                loss = F.binary_cross_entropy_with_logits(
                    ldam_outputs, targets)
            else:
                loss = F.binary_cross_entropy_with_logits(outputs, targets)

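            # Standard backpropagation and parameter update.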
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            m.track_loss(loss)
            # m.track_num_correct(preds, affinities)

        m.end_epoch()
    m.end_run()
    hype = '_'.join([f'{k}_{v}' for k, v in hyper_params._asdict().items()])
    m.save(f'../results/train_results_{hype}')
    logging.info("Training finished.\n")

    # Evaluate on the training set
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)
    probabs, targets, _, _ = evaluate(model,
                                      train_loader,
                                      device,
                                      dtset='train')
    compute_scores(probabs, targets, hyper_params, dtset='train')

    # Evaluate on the dev set
    dev_loader = DataLoader(dev_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1)
    probabs, targets, _, _ = evaluate(model, dev_loader, device, dtset='dev')
    compute_scores(probabs, targets, hyper_params, dtset='dev')

    # Evaluate on the test set
    test_loader = DataLoader(test_set,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=1)
    probabs, targets, full_hadm_ids, full_attn_weights = evaluate(model,
                                                                  test_loader,
                                                                  device,
                                                                  dtset='test')
    compute_scores(probabs,
                   targets,
                   hyper_params,
                   dtset='test',
                   full_hadm_ids=full_hadm_ids,
                   full_attn_weights=full_attn_weights)
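For reference, a minimal invocation sketch. The source never shows how hyper_params is constructed; the namedtuple below is inferred from the hyper_params._asdict() call above, and every concrete value is a placeholder:

# Hypothetical usage (all values are placeholders):
from collections import namedtuple

HyperParams = namedtuple('HyperParams', ['learning_rate', 'num_epoch'])
hyper_params = HyperParams(learning_rate=1e-3, num_epoch=30)

train(model, train_set, dev_set, test_set, hyper_params,
      batch_size=16, device='cuda')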