x_valid = []
for i, fold in enumerate(tqdm(conf.valid_fold)):
    file_name = f"bat_{conf.scale}_s_{conf.input_rows}x{conf.input_cols}x{conf.input_deps}_{fold}.npy"
    s = np.load(os.path.join(conf.data, file_name))
    x_valid.extend(s)
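# add a channel axis: (N, 64, 64, 32) -> (N, 1, 64, 64, 32)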
x_valid = np.expand_dims(np.array(x_valid), axis=1)

print("x_train: {} | {:.2f} ~ {:.2f}".format(x_train.shape, np.min(x_train), np.max(x_train)))
print("x_valid: {} | {:.2f} ~ {:.2f}".format(x_valid.shape, np.min(x_valid), np.max(x_valid)))

# generate_pair yields (transformed, original) sub-volume batches for the
# self-supervised reconstruction task
training_generator = generate_pair(x_train, conf.batch_size, conf)
validation_generator = generate_pair(x_valid, conf.batch_size, conf)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = unet3d.UNet3D()
model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
model.to(device)

print("Total CUDA devices: ", torch.cuda.device_count())

summary(model, (1,conf.input_rows,conf.input_cols,conf.input_deps), batch_size=-1)
criterion = nn.MSELoss()

if conf.optimizer == "sgd":
    optimizer = torch.optim.SGD(model.parameters(), conf.lr, momentum=0.9,
                                weight_decay=0.0, nesterov=False)
elif conf.optimizer == "adam":
    optimizer = torch.optim.Adam(model.parameters(), conf.lr)
else:
    raise ValueError(f"unknown optimizer: {conf.optimizer}")
Example #2 (File: main.py, Project: JXQI/ModelsGenesis)
class TargetNet(nn.Module):
    # header reconstructed from the usage below; the n_class=1 default is an
    # assumption consistent with TargetNet(base_model) and nn.BCELoss
    def __init__(self, base_model, n_class=1):
        super(TargetNet, self).__init__()
        self.base_model = base_model
        self.dense_1 = nn.Linear(512, 1024, bias=True)
        self.dense_2 = nn.Linear(1024, n_class, bias=True)

    def forward(self, x):
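        # run the encoder for its side effect: UNet3D caches its 512-channel
        # bottleneck activation as self.out512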
        self.base_model(x)
        self.base_out = self.base_model.out512
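        # global average pooling over the full spatial extent -> (N, 512)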
        self.out_glb_avg_pool = F.avg_pool3d(
            self.base_out,
            kernel_size=self.base_out.size()[2:]).view(self.base_out.size()[0],
                                                       -1)
        self.linear_out = self.dense_1(self.out_glb_avg_pool)
        final_out = self.dense_2(F.relu(self.linear_out))
        return final_out


base_model = unet3d.UNet3D()
# Load pre-trained weights
weight_dir = 'pretrained_weights/Genesis_Chest_CT.pt'
checkpoint = torch.load(weight_dir, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']
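# keys saved from an nn.DataParallel model are prefixed with "module.";
# strip the prefix so they match the unwrapped model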
unParalled_state_dict = {}
for key, value in state_dict.items():
    unParalled_state_dict[key.replace("module.", "")] = value
base_model.load_state_dict(unParalled_state_dict)
target_model = TargetNet(base_model)
target_model.to(config.device)
target_model = nn.DataParallel(
    target_model, device_ids=list(range(torch.cuda.device_count())))
criterion = nn.BCELoss()
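# note: forward() returns raw logits, while nn.BCELoss expects probabilities;
# presumably a sigmoid is applied during training (or use BCEWithLogitsLoss)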
# the call was truncated in the original listing; the trailing arguments are
# assumed to match the SGD setup used in the other examples here
optimizer = torch.optim.SGD(target_model.parameters(),
                            config.lr,
                            momentum=0.9,
                            weight_decay=0.0,
                            nesterov=False)
Example #3
def replicate_model_genesis():

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    conf = models_genesis_config()
    conf.display()

    # collect sub-volumes as lists of numpy arrays
    x_train = []
    for fold in conf.train_fold:
        file_name = f"bat_{conf.scale}_s_{conf.input_rows}x{conf.input_cols}x{conf.input_deps}_{fold}.npy"
        s = np.load(os.path.join(conf.data, file_name))
        x_train.extend(s)

    x_train = np.expand_dims(
        np.array(x_train),
        axis=1)  # (2848, 64, 64, 32) -> (2848, 1, 64, 64, 32)

    x_valid = []
    for fold in tqdm(conf.valid_fold):
        file_name = f"bat_{conf.scale}_s_{conf.input_rows}x{conf.input_cols}x{conf.input_deps}_{fold}.npy"
        s = np.load(os.path.join(conf.data, file_name))
        x_valid.extend(s)
    x_valid = np.expand_dims(np.array(x_valid), axis=1)

    print("x_train: {} | {:.2f} ~ {:.2f}".format(x_train.shape,
                                                 np.min(x_train),
                                                 np.max(x_train)))
    print("x_valid: {} | {:.2f} ~ {:.2f}".format(x_valid.shape,
                                                 np.min(x_valid),
                                                 np.max(x_valid)))

    # make (x, y) pairs for autoencoding with transformations
    training_generator = generate_pair(x_train, conf.batch_size, conf)
    validation_generator = generate_pair(x_valid, conf.batch_size, conf)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = unet3d.UNet3D()
    model = nn.DataParallel(
        model, device_ids=list(range(torch.cuda.device_count())))
    model.to(device)

    print("Total CUDA devices: ", torch.cuda.device_count())

    # summary(model, (1,conf.input_rows,conf.input_cols,conf.input_deps), batch_size=-1)
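    # reconstruction objective: the model restores the original sub-volume
    # from its transformed version (the Models Genesis pretext task)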
    criterion = nn.MSELoss()

    if conf.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    conf.lr,
                                    momentum=0.9,
                                    weight_decay=0.0,
                                    nesterov=False)
    elif conf.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(), conf.lr)
    else:
        raise ValueError(f"unknown optimizer: {conf.optimizer}")

    # halve the learning rate every int(0.8 * patience) epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=int(conf.patience * 0.8),
                                                gamma=0.5)

    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    best_loss = float("inf")
    initial_epoch = 0
    num_epoch_no_improvement = 0
    sys.stdout.flush()

    # resume from a checkpoint if one was provided
    if conf.weights is not None:
        checkpoint = torch.load(conf.weights, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        initial_epoch = checkpoint["epoch"]
        print("Loading weights from ", conf.weights)
    sys.stdout.flush()

    for epoch in range(initial_epoch, conf.nb_epoch):
        scheduler.step(epoch)  # epoch-indexed so the LR is correct after resuming
        model.train()
        for iteration in range(int(x_train.shape[0] // conf.batch_size)):
            image, gt = next(training_generator)  # each (6, 1, 64, 64, 32)
            # tile the target across nb_class output channels
            gt = np.repeat(gt, conf.nb_class, axis=1)
            image, gt = torch.from_numpy(image).float().to(
                device), torch.from_numpy(gt).float().to(device)
            pred = model(image)
            loss = criterion(pred, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())  # keep full precision for averaging
            if (iteration + 1) % 5 == 0:
                print("Epoch [{}/{}], iteration {}, Loss: {:.6f}".format(
                    epoch + 1, conf.nb_epoch, iteration + 1,
                    np.average(train_losses)))
                sys.stdout.flush()

        # evaluate on held-out data without tracking gradients
        with torch.no_grad():
            model.eval()
            print("validating...")
            for i in range(int(x_valid.shape[0] // conf.batch_size)):
                x, y = next(validation_generator)
                y = np.repeat(y, conf.nb_class, axis=1)
                image, gt = torch.from_numpy(x).float(), torch.from_numpy(
                    y).float()
                image = image.to(device)
                gt = gt.to(device)
                pred = model(image)
                loss = criterion(pred, gt)
                valid_losses.append(loss.item())

        # logging
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        print("Epoch {}, validation loss is {:.4f}, training loss is {:.4f}".
              format(epoch + 1, valid_loss, train_loss))
        train_losses = []
        valid_losses = []
        if valid_loss < best_loss:
            print("Validation loss decreases from {:.4f} to {:.4f}".format(
                best_loss, valid_loss))
            best_loss = valid_loss
            num_epoch_no_improvement = 0
            # save model
            torch.save(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict()
                }, os.path.join(conf.model_path, "Genesis_Chest_CT.pt"))
            print("Saving model ",
                  os.path.join(conf.model_path, "Genesis_Chest_CT.pt"))
        else:
            print(
                "Validation loss does not decrease from {:.4f}, num_epoch_no_improvement {}"
                .format(best_loss, num_epoch_no_improvement))
            num_epoch_no_improvement += 1
        if num_epoch_no_improvement == conf.patience:
            print("Early Stopping")
            break
        sys.stdout.flush()
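

# A minimal sketch of an entry point (not part of the original listing):
if __name__ == "__main__":
    replicate_model_genesis()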