예제 #1
0
def train_model(model: nn, dataset: Dataset,
                num_epochs: int = 10, specialize: bool = False) -> nn:
    """Train a sequence model on every song of ``dataset`` for ``num_epochs``.

    Each song is cut into windows of ``const.SEQ_LEN`` time steps. For each
    window the model is run, the loss against the matching slice of the
    output tensors is back-propagated, and one optimizer step is taken. The
    average window loss of every song is printed.

    :param model: model called as ``model(x_seq, hidden, tag)`` that returns
        ``(output, hidden)``.
    :param dataset: sized iterable of ``(input_tensors, tag, output_tensors)``.
    :param num_epochs: number of passes over ``dataset``.
    :param specialize: if True each song starts with ``hidden=None`` instead
        of ``model.init_hidden()``.
    :return: the same ``model`` object, trained in place.
    """
    loss_function = const.LOSS_FUNCTION
    optimizer = const.OPTIMIZER(model.parameters(), lr=const.LEARNING_RATE)

    for epoch in range(num_epochs):
        for i, song in enumerate(dataset):
            input_tensors, tag, output_tensors = song
            hidden = None if specialize else model.init_hidden()

            song_length = len(input_tensors)
            song_losses = []
            # Windows start at step 1 and never include the last step.
            for t in range(1, song_length - 1, const.SEQ_LEN):
                roof = min(t + const.SEQ_LEN, song_length - 1)
                x_seq = input_tensors[t:roof]
                y_t = output_tensors[t:roof]

                model.zero_grad()
                # NOTE(review): `hidden` is carried into the next window without
                # being detached, and retain_graph=True keeps earlier graphs alive.
                # Back-propagating through a window whose parameters were already
                # updated by optimizer.step() may raise an in-place modification
                # error; confirm whether hidden should be detached here.
                output, hidden = model(x_seq, hidden, tag)

                loss = loss_function(output, y_t.view(roof - t, -1))
                song_losses.append(loss.item())
                loss.backward(retain_graph=True)
                optimizer.step()

            print("Epoch", epoch + 1, "Song", i + 1, "/", len(dataset))
            if song_losses:
                print("Avg loss for this song", sum(song_losses) / len(song_losses))
            else:
                # Songs with fewer than 3 steps yield no window; the original
                # divided by zero here.
                print("Avg loss for this song", "n/a (song too short)")

    return model
예제 #2
0
def test_class_accuracy(network: torch.nn, loader: torch.utils.data.DataLoader,
                        device: torch.device) -> float:
    """
    Test the class accuracy of a network on a dataset.

    Puts ``network`` in eval mode and sets ``network.classify = True``. Both
    changes are left in place on return. Every batch of ``loader`` is then
    run without gradient tracking. A prediction counts as correct when its
    top-1 index equals the label's index in the VGGFace2 "train" class list.

    :param network: network to test
    :param loader: loader yielding ``(inputs, labels)``, where labels are class names
    :param device: device to use
    :return: number of correct top-1 predictions divided by
        ``3 * len(loader.dataset)``
    """
    network.eval()
    network.classify = True

    classes = utils.datasets.get_vggface2_classes("train")
    correct = 0

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = network(inputs.to(device))

            # Convert the class names to class ids (the former `.T` on this
            # 1-D tensor was a no-op).
            label_ids = torch.tensor([classes.index(label) for label in labels],
                                     dtype=torch.long, device=device)

            output_labels = torch.topk(outputs, 1).indices.view(-1)
            correct += torch.count_nonzero(output_labels == label_ids).item()

    # NOTE(review): the factor 3 is kept from the original code; presumably each
    # dataset item accounts for three predictions — confirm.
    return correct / (len(loader.dataset) * 3)
예제 #3
0
def train_fn(model: torch.nn, data_loader: DataLoader, optimizer: optim,
             device: torch.device, epoch: int):
    """Run one training epoch of a model that returns a dict of losses.

    For each batch, the images are stacked into one tensor and the targets'
    tensors are moved to ``device``. The model is called as
    ``model(images, targets)``. The sum of the returned losses is
    back-propagated, and one optimizer step is taken. Every 10 batches, a
    progress line with the processed-image count and the individual losses
    is printed.

    :param model: model returning a dict with at least the keys
        ``loss_classifier``, ``loss_box_reg``, ``loss_objectness`` and
        ``loss_rpn_box_reg`` (these are read for the progress line).
    :param data_loader: loader yielding ``(images, targets)`` where images is a
        sequence of tensors and targets a sequence of dicts of tensors.
    :param optimizer: optimizer stepping the model's parameters.
    :param device: device the batch is moved to.
    :param epoch: zero-based epoch index (printed as ``epoch + 1``).
    """
    model.train()
    start_time = datetime.datetime.now()
    dataset_size = len(data_loader.dataset)
    num_images: int = 0
    for i, (images, targets) in enumerate(data_loader):
        images = torch.stack([image.to(device) for image in images])
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        num_images += len(images)

        optimizer.zero_grad()
        loss_dict: Dict[str, torch.Tensor] = model(images, targets)
        loss: torch.Tensor = sum(loss_dict.values())

        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print('-' * 50)
            # Show processed/total images; the original printed the dataset
            # size in place of the processed count.
            print(
                f'Epoch {epoch+1}[{num_images:,}/{dataset_size:,} ({num_images / dataset_size * 100:.2f}%)] '
                f'- Elapsed time: {datetime.datetime.now() - start_time}\n'
                f' - loss: classifier={loss_dict["loss_classifier"]:.6f}, box_reg={loss_dict["loss_box_reg"]:.6f}, '
                f'objectness={loss_dict["loss_objectness"]:.6f}, rpn_box_reg={loss_dict["loss_rpn_box_reg"]:.6f}'
            )
예제 #4
0
 def save_checkpoint(self, epoch_score: float, model: torch.nn,
                     model_path: str):
     """Save ``model``'s state dict to ``model_path`` if ``epoch_score`` is finite.

     The message is printed and the file is written whenever the score is
     finite. Any "has it improved?" decision is presumably made by the
     caller. ``self.val_score`` is always set to ``epoch_score``, even when
     the score is inf/NaN and nothing is saved.

     :param epoch_score: validation score of the current epoch
     :param model: model whose ``state_dict()`` is saved
     :param model_path: destination passed to ``torch.save``
     """
     # NaN != NaN, so the former `epoch_score not in [..., -np.nan, np.nan]`
     # only matched the np.nan object itself and let computed NaNs through.
     if np.isfinite(float(epoch_score)):
         print(
             f'Validation score improved ({self.val_score:.6f} --> {epoch_score:.6f}). Saving model...\n'
         )
         torch.save(model.state_dict(), model_path)
     self.val_score = epoch_score
    def _get_cv_stats(self, model: torch.nn) -> Tuple[torch.tensor, ...]:
        """
        Collect CV data accuracy, loss, prediction distribution, and targets
        distribution.

        The model is run over ``self.cvloader`` in eval mode under
        ``torch.no_grad()``. Afterwards its previous train/eval mode is
        restored.

        Arguments:
            model {torch.nn} -- model to evaluate

        Returns:
            Tuple[torch.tensor * 4] -- accuracy and loss (both averaged over
                ``len(self.cvloader.dataset)``), the predicted class ids and
                the target class ids of every CV sample, all moved to
                CPU_DEVICE. The id tensors are empty if the loader yields
                no batches.
        """
        cv_acc = torch.tensor(0.0).to(self.device)
        cv_loss = torch.tensor(0.0).to(self.device)
        samples_no = float(len(self.cvloader.dataset))
        predicted_chunks = []
        target_chunks = []
        was_training = model.training

        with torch.no_grad():
            model = model.eval()

            for inputs, targets in self.cvloader:
                # The last batch can contain fewer items than the others.
                batch_size = inputs.shape[0]

                targets = targets.to(self.device)
                outputs = model(inputs.to(self.device))
                predicted = outputs.argmax(1)

                # Gather per-batch results and concatenate once at the end
                # instead of re-concatenating growing tensors every batch.
                predicted_chunks.append(predicted.long())
                target_chunks.append(targets.long())

                cv_acc += (predicted == targets).sum()
                # The criterion returns a per-batch value; scaling by the batch
                # size lets the final division by samples_no give a sample mean
                # (assumes a mean-reduced criterion — confirm).
                cv_loss += self.criterion(outputs, targets) * batch_size

            cv_acc = (cv_acc / samples_no).to(CPU_DEVICE)
            cv_loss = (cv_loss / samples_no).to(CPU_DEVICE)
            if predicted_chunks:
                outputs_dist = torch.cat(predicted_chunks).to(CPU_DEVICE)
                targets_dist = torch.cat(target_chunks).to(CPU_DEVICE)
            else:
                # Empty loader: the original failed calling `.to` on None.
                outputs_dist = torch.empty(0, dtype=torch.long, device=CPU_DEVICE)
                targets_dist = torch.empty(0, dtype=torch.long, device=CPU_DEVICE)

        # Restore the caller's mode rather than always switching to train mode.
        model.train(was_training)
        return cv_acc, cv_loss, outputs_dist, targets_dist
예제 #6
0
파일: test.py 프로젝트: unerue/competition
def valid_fn(
    model: torch.nn,
    data_loader: DataLoader,
    device: torch.device,
    class_names: Dict[int, str],
    valid: bool = True):
    """
    Evaluate the model and save its box predictions as XML.

    Every image gets an ``<image name=...>`` element holding one ``<predict>``
    child per detected box, with its class name, score and integer corners.
    The tree is written to ``./output/validation.xml`` when ``valid`` is
    True and to ``./output/prediction.xml`` otherwise.

    :param model: detection model returning, per image, a dict with
        ``boxes``, ``labels`` and ``scores``.
    :param data_loader: loader yielding ``(images, _)``; its dataset must expose
        an ``images`` list of names.
    :param device: device the images are moved to.
    :param class_names: mapping from label id to class name.
    :param valid: choose the validation or the prediction output file.
    """
    xml_root = ET.Element('predictions')
    batch_size: int = data_loader.batch_size
    model.eval()
    with torch.set_grad_enabled(False):
        # Wrap the loader itself so tqdm knows the total number of batches.
        for i, (images, _) in enumerate(tqdm(data_loader)):
            images = [image.to(device) for image in images]
            outputs = model(images)
            for j, output in enumerate(outputs):
                # NOTE(review): assumes the loader is not shuffled and every
                # batch but the last is full, so this index matches
                # dataset.images — confirm.
                image_name: str = data_loader.dataset.images[i * batch_size + j]
                xml_image = ET.SubElement(xml_root, 'image', {'name': image_name})

                boxes = output['boxes'].detach().cpu().numpy()
                labels = output['labels'].detach().cpu().numpy()
                scores = output['scores'].detach().cpu().numpy()

                for box, label, score in zip(boxes, labels, scores):
                    attribs = {
                        'class_name': class_names[label],
                        'score': str(float(score)),
                        'x1': str(int(box[0])),
                        'y1': str(int(box[1])),
                        'x2': str(int(box[2])),
                        'y2': str(int(box[3])),
                    }
                    ET.SubElement(xml_image, 'predict', attribs)

    indent(xml_root)
    tree = ET.ElementTree(xml_root)
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('./output/', exist_ok=True)

    output_path = './output/validation.xml' if valid else './output/prediction.xml'
    tree.write(output_path)
    print('Save predicted labels.xml...\n')
예제 #7
0
def init_weights(net: torch.nn,
                 init_type: str = "normal",
                 init_gain: float = 0.02) -> None:
    """Initialize network weights in place.
    Ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

    Parameters:
    -----------
        net:         network to be initialized
        init_type:   the name of an initialization method
                     `normal` | `xavier` | `kaiming` | `orthogonal`
        init_gain:   scaling factor for normal, xavier and orthogonal.
    Conv/Linear weights use the chosen method and their biases are zeroed.
    BatchNorm2d scales are drawn from N(1, init_gain) and their biases are
    zeroed. BatchNorm2d layers without affine parameters are skipped.
    We use 'normal' in the original pix2pix and CycleGAN paper.
    But xavier and kaiming might work better for some applications.
    Feel free to try yourself.

    Raises:
        NotImplementedError: if ``init_type`` is not a known method and the
            network contains a Conv/Linear layer.
    """
    def init_func(m):
        """Initialize a single submodule ``m`` according to ``init_type``."""
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (classname.find("Conv") != -1
                                     or classname.find("Linear") != -1):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" %
                    init_type)
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            # BatchNorm's weight is not a matrix, so only a normal distribution
            # applies. With affine=False weight and bias are None: skip them
            # instead of failing on `None.data`.
            if getattr(m, "weight", None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
예제 #8
0
파일: test.py 프로젝트: unerue/competition
def test_fn(
    model: torch.nn,
    data_loader: DataLoader,
    class_nums: Dict,
    device: torch.device):
    """Predict instance masks and write them as polygons to ./output/prediction.xml.

    Each soft mask's first channel is binarized at 0.1 and converted to
    polygons with ``Mask(...).polygons()``. The first polygon is serialized
    as ``"x,y;x,y;..."``. Predictions for which no polygon can be traced are
    skipped.

    :param model: model returning, per image, a dict with ``masks``,
        ``labels`` and ``scores``.
    :param data_loader: loader yielding ``(images, _)``; its dataset must expose
        ``image_ids``.
    :param class_nums: mapping from label id to class name.
    :param device: device the images are moved to.
    """
    # Name = last 17 characters of the second '.'-separated part of each id.
    image_ids = [image_id.split('.')[1][-17:] for image_id in data_loader.dataset.image_ids]
    xml_root = ET.Element('predictions')
    model.eval()
    batch_size = data_loader.batch_size
    with torch.no_grad():
        # Wrap the loader itself so tqdm knows the total number of batches.
        for i, (images, _) in enumerate(tqdm(data_loader)):
            images = [image.to(device) for image in images]
            outputs = model(images)
            for j, output in enumerate(outputs):
                # NOTE(review): assumes an unshuffled loader whose batches are
                # all full except the last — confirm.
                image_name = image_ids[i * batch_size + j]
                xml_image = ET.SubElement(xml_root, 'image', {'name': image_name})

                masks = output['masks'].detach().cpu().numpy()
                labels = output['labels'].detach().cpu().numpy()
                scores = output['scores'].detach().cpu().numpy()

                for mask, label, score in zip(masks, labels, scores):
                    mask_bin = np.where(mask[0] > 0.1, True, False)
                    points = Mask(mask_bin).polygons().points
                    if len(points) == 0:
                        # No polygon could be traced (e.g. an empty mask); the
                        # original crashed with IndexError on points[0].
                        continue
                    point = ''.join([str(p[0]) + ',' + str(p[1]) + ';' for p in points[0]])
                    attribs = {
                        'class_name': class_nums[label],
                        'score': str(float(score)),
                        'polygon': point,
                    }
                    ET.SubElement(xml_image, 'predict', attribs)

    indent(xml_root)
    tree = ET.ElementTree(xml_root)
    if not os.path.exists('./output/'):
        print('Not exists ./output/ making an weight folder...')
    # exist_ok avoids failing if the folder appears between the check and here.
    os.makedirs('./output/', exist_ok=True)

    tree.write('./output/prediction.xml')
    print('Save predicted labels.xml...\n')
 def __init__(
         self,
         args,
         encoder: torch.nn,  # encoder for the online network; the target network is a copy of it
         predictor: torch.nn,  # predictor network comes after online network
         optimizer: torch.optim,  # cosine annealing learning rate after warm-up, wo/ restart
         **params):
     """Set up the online/target encoders, the predictor and training settings.

     :param args: namespace read for ``device``, ``max_epoch``, ``batch_size``,
         ``num_workers`` and ``resume``.
     :param encoder: online encoder, moved to ``args.device``.
     :param predictor: predictor network, moved to ``args.device``.
     :param optimizer: optimizer, stored as given.
     :param params: accepted but currently unused.
     """
     import copy

     super(BYOL, self).__init__()
     self.args = args
     self.online_net = encoder.to(args.device)
     # `Module.to` returns the module itself, so the former second call
     # `encoder.to(args.device)` made target_net the very same object as
     # online_net. The target network must be an independent copy.
     self.target_net = copy.deepcopy(self.online_net)
     self.predictor = predictor.to(args.device)
     self.optimizer = optimizer
     self.max_epoch = args.max_epoch
     self.batch_size = args.batch_size
     self.num_workers = args.num_workers  # num_workers for data loader
     self.device = args.device  # set cuda device
     # NOTE(review): `path_summary` is not defined in this method; presumably a
     # module-level setting — confirm.
     self.writer = SummaryWriter(log_dir=path_summary)
     self.resume = args.resume
예제 #10
0
def predict(model: torch.nn, final_activ: torch.nn.functional, dl, out_type, with_preds=False):
    """Run ``model`` over every batch of ``dl`` and gather the activated outputs.

    Result buffers sized ``len(dl.dataset)`` are allocated on
    ``DEFAULT_DEVICE`` from the first batch. Each batch is written into
    rows ``[n * dl.batch_size, (n + 1) * dl.batch_size)``, so the loader
    should not drop or reorder items.

    :param model: model applied to the first element of every batch.
    :param final_activ: activation forwarded to ``apply_final_activ``.
    :param dl: data loader whose batches are sequences of tensors.
    :param out_type: output type forwarded to ``apply_final_activ``.
    :param with_preds: also return the raw predictions returned by
        ``apply_final_activ``.
    :return: ``outputs`` or ``(outputs, preds)`` when ``with_preds`` is True.
        Both are None if ``dl`` yields no batches.
    """
    model.eval()
    outputs = preds = None
    # Inference only: avoid building an autograd graph for every batch.
    with torch.no_grad():
        for n, batch in enumerate(dl):
            batch = [nb.to(device=DEFAULT_DEVICE) for nb in batch]
            preds_batch = model(batch[0])
            outputs_batch = apply_final_activ(
                input=preds_batch, out_type=out_type, final_activ=final_activ, with_preds=with_preds
            )
            if with_preds:
                outputs_batch, preds_batch = outputs_batch
            if n == 0:
                # Allocate full-size buffers using the first batch's width/dtype.
                if with_preds:
                    preds = torch.empty(len(dl.dataset), preds_batch.shape[1],
                                        device=DEFAULT_DEVICE, dtype=preds_batch.dtype)
                outputs = torch.empty(len(dl.dataset),
                                      device=DEFAULT_DEVICE, dtype=outputs_batch.dtype)
            start, stop = n * dl.batch_size, (n + 1) * dl.batch_size
            if with_preds:
                preds[start:stop, :] = preds_batch
            outputs[start:stop] = outputs_batch

    if with_preds:
        return outputs, preds
    # The original returned the undefined name `output` here (NameError).
    return outputs
예제 #11
0
    def __init__(
        self,
        model: torch.nn.Module,
        loss_function: torch.nn,
        optimizer: torch.optim,
        epochs: int,
        model_info: list,
        save_period: int,
        savedir: str,
        lr_scheduler: torch.optim.lr_scheduler = None,
        device: str = None,
    ):
        """
        Args:
            model (torch.nn.Module): The model to be trained; moved to ``device``.
            loss_function: The loss function module; moved to ``device``.
            optimizer (torch.optim): The optimizer, stored as given.
            epochs (int): Number of epochs to train.
            model_info (list): Stored as ``self.model_info``.
            save_period (int): Stored as ``self.save_period``.
            savedir (str): Base directory; checkpoints go under ``savedir/<YYYY-MM-DD>``.
            lr_scheduler (torch.optim.lr_scheduler): pytorch lr_scheduler for manipulating the learning rate
            device (str): string of the device to be trained on, e.g., "cuda:0".
                If None, "cuda" is used when available, otherwise "cpu".
        """
        # torch.device(None) raises, so resolve the optional default here.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Model to device
        self.device = torch.device(device)

        self.model = model.to(self.device)
        self.lr_scheduler = lr_scheduler
        self.loss_function = loss_function.to(self.device)
        self.optimizer = optimizer

        self.epochs = epochs
        self.model_info = model_info
        self.save_period = save_period
        self.start_epoch = 1

        self.checkpoint_dir = Path(savedir) / Path(
            datetime.today().strftime('%Y-%m-%d'))
        # Minimum validation loss achieved so far; starts at the largest possible float.
        self.min_validation_loss = sys.float_info.max
예제 #12
0
def update_training_pool_ids_2(net: torch.nn,
                               training_pool_ids_path: str,
                               all_training_data,
                               device: str,
                               acquisition_func: str = "cfe"):
    """
    Add the 100 highest-scoring active-pool images to the training pool.

    training_pool_ids_path: path to the json file with the image ids already in the training pool.
    all_training_data: source passed to ``get_pool_data`` to list every training image id.
    acquisition_func: name of the acquisition function:
                    "cfe" (category_first_entropy), "mfe" (mean_first_entropy),
                    "mi" (mutual_information)
    Every image not yet in the pool is scored with the acquisition function.
    The ids of the 100 highest scores are then added with ``add_image_id_to_pool``.

    Raises:
        ValueError: if ``acquisition_func`` is not one of "cfe", "mfe", "mi".
    """
    # Validate first: the original printed an error, loaded the data, and then
    # crashed calling None.
    if acquisition_func == "cfe":
        evaluation_criteria = acquisition_function.category_first_entropy
    elif acquisition_func == "mfe":
        evaluation_criteria = acquisition_function.mean_first_entropy
    elif acquisition_func == "mi":
        evaluation_criteria = acquisition_function.mutual_information
    else:
        raise ValueError(f"Error choosing acquisition function: {acquisition_func!r}")

    training_pool_data = get_pool_data(training_pool_ids_path)
    all_training_data = get_pool_data(all_training_data)
    active_pool = set(all_training_data) - set(training_pool_data)
    dataset = RestrictedDataset(dir_img, dir_mask, list(active_pool))
    pool_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True)

    values = []
    imgs_id = []

    net.eval()
    with tqdm(total=len(dataset), desc='STD calculating', unit='batch',
              leave=False) as pbar:
        for batch in pool_loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            batch_values = evaluation_criteria(GAUSS_ITERATION, net, imgs)
            # With batch_size=1 the scores and ids of a batch are recorded once,
            # and only when the acquisition function returned scores.
            if len(batch_values) > 0:
                values.extend(batch_values)
                imgs_id.extend(batch['id'])
            pbar.update()

    if not values:
        # Nothing was scored (e.g. empty active pool); the original crashed
        # unpacking an empty zip here.
        print("No images scored; nothing added to the training pool.")
        return

    values, imgs_id = zip(*sorted(zip(values, imgs_id)))  # order = ascending
    print("length of value/imgs_id: ", len(values), len(imgs_id))
    top_100img = imgs_id[-100:]  # the higher
    for img_id in top_100img:
        add_image_id_to_pool(img_id, training_pool_ids_path)
    print("Adding successfully!")
예제 #13
0
def load_frzn_model(
    model: torch.nn,
    path: str,
    current_args: Namespace = None,
    cuda: bool = None,
    logger: logging.Logger = None,
) -> MoleculeModel:
    """
    Loads a model checkpoint and copies its frozen parameters into ``model``.

    Encoder (MPNN) parameters are copied from the checkpoint at ``path`` into
    ``model``'s state dict with ``overwrite_state_dict``, depending on the
    molecule counts and the ``checkpoint_frzn`` / ``freeze_first_only`` /
    ``frzn_ffn_layers`` settings. When ``frzn_ffn_layers > 0``, the first
    FFN layers are copied too. The updated state dict is then loaded into
    ``model``.

    :param model: The model whose parameters are overwritten.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Despite the None default, its
        attributes are read unconditionally, so it must be provided.
    :param cuda: Currently unused.
    :param logger: A logger; ``print`` is used when None.
    :return: ``model`` with the frozen parameters loaded.
    :raises ValueError: for incompatible molecule counts or flags, or when
        ``frzn_ffn_layers >= ffn_num_layers`` in the multi-molecule case.
    """
    debug = logger.debug if logger is not None else print

    # torch.load unpickles the file: only load trusted checkpoints.
    loaded_mpnn_model = torch.load(path,
                                   map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model["state_dict"]
    loaded_args = loaded_mpnn_model["args"]

    model_state_dict = model.state_dict()

    def encoder_param_names(mol_num):
        """Return the four encoder parameter names of molecule ``mol_num``."""
        return [
            f"encoder.encoder.{mol_num}.W_i.weight",
            f"encoder.encoder.{mol_num}.W_h.weight",
            f"encoder.encoder.{mol_num}.W_o.weight",
            f"encoder.encoder.{mol_num}.W_o.bias",
        ]

    def all_encoder_param_names(num_molecules):
        """Return the encoder parameter names of molecules 0..num_molecules-1 as flat strings."""
        # The original built lists of tuples here, passing tuples as names.
        return [name for mol_num in range(num_molecules)
                for name in encoder_param_names(mol_num)]

    def ffn_param_names(num_layers):
        """Return weight/bias names of the first ``num_layers`` frozen FFN layers."""
        # FFN parameters live at indices i * 3 + 1 (1, 4, 7, ...); the
        # multi-molecule branch previously used i + 3 + 1 by mistake.
        return [name for i in range(num_layers)
                for name in (f"ffn.{i * 3 + 1}.weight", f"ffn.{i * 3 + 1}.bias")]

    def overwrite(loaded_names, model_names=None):
        """Copy checkpoint params ``loaded_names`` into ``model_names`` (same names by default)."""
        nonlocal model_state_dict
        if model_names is None:
            model_names = loaded_names
        for loaded_name, model_name in zip(loaded_names, model_names):
            model_state_dict = overwrite_state_dict(
                loaded_name, model_name, loaded_state_dict, model_state_dict)

    if loaded_args.number_of_molecules == 1 and current_args.number_of_molecules == 1:
        encoder_names = encoder_param_names(0)
        if current_args.checkpoint_frzn is not None:
            # Freeze the MPNN
            overwrite(encoder_names)

        if current_args.frzn_ffn_layers > 0:
            # Freeze MPNN and FFN layers
            overwrite(encoder_names + ffn_param_names(current_args.frzn_ffn_layers))

        if current_args.freeze_first_only:
            debug(
                "WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)"
            )

    elif loaded_args.number_of_molecules == 1 and current_args.number_of_molecules > 1:
        if (current_args.checkpoint_frzn is not None
                and current_args.freeze_first_only
                and current_args.frzn_ffn_layers <= 0):
            # Only freeze the first MPNN.
            overwrite(encoder_param_names(0))
        if (current_args.checkpoint_frzn is not None
                and not current_args.freeze_first_only
                and current_args.frzn_ffn_layers <= 0):
            # Duplicate the checkpoint's single encoder into every molecule's encoder.
            overwrite(
                encoder_param_names(0) * current_args.number_of_molecules,
                all_encoder_param_names(current_args.number_of_molecules),
            )

        if current_args.frzn_ffn_layers > 0:
            raise ValueError(
                f"Number of molecules from checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must equal current number of molecules ({current_args.number_of_molecules})!"
            )

    elif loaded_args.number_of_molecules > 1 and current_args.number_of_molecules > 1:
        if loaded_args.number_of_molecules != current_args.number_of_molecules:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must either match current model ({current_args.number_of_molecules}) or equal 1."
            )

        if current_args.freeze_first_only:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                "must be equal to 1 for freeze_first_only to be used!")

        encoder_names = all_encoder_param_names(current_args.number_of_molecules)
        if current_args.checkpoint_frzn is not None and not current_args.frzn_ffn_layers > 0:
            overwrite(encoder_names)

        if current_args.frzn_ffn_layers > 0:
            overwrite(encoder_names + ffn_param_names(current_args.frzn_ffn_layers))

        if current_args.frzn_ffn_layers >= current_args.ffn_num_layers:
            raise ValueError(
                f"Number of frozen FFN layers ({current_args.frzn_ffn_layers}) "
                f"must be less than the number of FFN layers ({current_args.ffn_num_layers})!"
            )

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model
def generate_sample_radial(
    model: torch.nn,
    features: int = 10,
    targets: int = 1,
    iterations: int = 10,
    start_image: Tensor = None,
    image_size: int = 64,
    return_intermediate: bool = True,
    device: str = "cpu",
    rotations: int = 2,
    batch_size: int = 256,
    all_random: bool = True,
) -> List[Tensor]:
    """
    Create a sample generated either on top of a starting image or from a
    random image, and return the image produced each time the network is
    applied.
    Args :
      model : The nn model
      features : The number of points used for interpolation
      targets : Currently unused.
      iterations : Number of times to apply the network to the image
      start_image : A normalized image to use as the initial value in range [-1, 1]
      image_size : The width of the image (assumed square)
      return_intermediate : Currently unused; every iteration's image is returned.
      device : The device to perform operations on
      rotations : The number of rotations used by the network
      batch_size : Currently unused.
      all_random : If True, each iteration restarts from a fresh random image
        (or from a copy of ``start_image`` when given) instead of feeding back
        the previous output.
    Returns :
      A list with one (3, image_size, image_size) image per iteration.
    """
    model.eval()

    if start_image is None:
        image = torch.rand([3, image_size, image_size], device=device) * 2 - 1
    else:
        image = copy.deepcopy(start_image)

    stripe_list = positions_from_mesh(
        width=image_size,
        height=image_size,
        device=device,
        rotations=rotations,
        normalize=True,
    )

    # These need to be normalized otherwise the max is the width of the grid
    stripe_list = torch.cat([val.unsqueeze(0) for val in stripe_list])

    indices, target_linear_indices = indices_from_grid(image_size, device=device)

    result_list = []
    # Inference only: without no_grad every returned image would keep its
    # autograd graph (and the model activations) alive.
    with torch.no_grad():
        for count in range(iterations):
            logger.info(f"Generating for count {count}")

            if all_random is True and start_image is None:
                image = torch.rand([3, image_size, image_size], device=device) * 2 - 1
            elif all_random is True and start_image is not None:
                image = copy.deepcopy(start_image)

            features_tensor, _ = random_radial_samples_from_image(
                img=image,
                stripe_list=stripe_list,
                image_size=image_size,
                feature_pixels=features,
                indices=indices,
                target_linear_indices=target_linear_indices,
                device=device,
            )

            result = model(features_tensor.flatten(1))
            image = result.reshape(image_size, image_size, 3).permute(2, 0, 1)

            result_list.append(image)

    return result_list
def generate_sample(
    model: torch.nn,
    features: int = 10,
    targets: int = 1,
    iterations: int = 10,
    image: Tensor = None,
    image_size: int = 64,
    return_intermediate: bool = True,
    device: str = "cpu",
    rotations: int = 2,
    batch_size: int = 256,
    all_random: bool = True,
) -> List[Tensor]:
    """
    Create a sample generated either on top of a starting image or from a
    random image, and return the image produced each time the network is
    applied.
    Args :
      model : The nn model
      features : The number of points used for interpolation
      targets : Currently unused.
      iterations : Number of times to apply the network to the image
      image : An unnormalized image (values in [0, 255]) to use as the initial value
      image_size : The width of the image (assumed square)
      return_intermediate : Currently unused; every iteration's image is returned.
      device : The device to perform operations on
      rotations : The number of rotations used by the network
      batch_size : Currently unused.
      all_random : If True, each iteration starts from a fresh random image
        instead of feeding back the previous output.
    Returns :
      A list with one (3, image_size, image_size) image per iteration.
    """
    num_pixels = image_size * image_size
    model.eval()

    if image is None:
        image = torch.rand([3, image_size, image_size], device=device) * 2 - 1
    else:
        # Map [0, 255] to [-1, 1].
        image = (image / 255) * 2 - 1

    stripe_list = positions_from_mesh(
        width=image_size,
        height=image_size,
        device=device,
        rotations=rotations,
        normalize=True,
    )

    # These need to be normalized otherwise the max is the width of the grid
    new_vals = torch.cat([val.unsqueeze(0) for val in stripe_list])

    # Every pixel is a target; this does not change between iterations.
    target_indices = torch.arange(start=0, end=num_pixels, device=device)

    result_list = []
    # Inference only: without no_grad every returned image would keep its
    # autograd graph alive.
    with torch.no_grad():
        for count in range(iterations):
            logger.info(f"Generating for count {count}")

            if all_random is True:
                image = torch.rand([3, image_size, image_size], device=device) * 2 - 1

            full_features = torch.cat([image, new_vals])
            channels, _, _ = full_features.shape
            # One row per pixel, one column per channel.
            full_features = full_features.reshape(channels, -1).permute(1, 0)

            # Each pixel index appears exactly `features` times, in random order.
            feature_indices = torch.remainder(
                torch.randperm(num_pixels * features, device=device), num_pixels
            )

            features_tensor = full_features[feature_indices].reshape(-1, features, channels)
            targets_tensor = full_features[target_indices].reshape(-1, 1, channels)

            # Distances are measured relative to target so remove that component
            features_tensor[:, :, 3:] = features_tensor[:, :, 3:] - targets_tensor[:, :, 3:]

            result = model(features_tensor.flatten(1))
            image = result.reshape(image_size, image_size, 3).permute(2, 0, 1)

            result_list.append(image)

    return result_list
예제 #16
0
def _frzn_encoder_param_names(mol_nums) -> list:
    """Return the W_i/W_h/W_o parameter names of ``encoder.encoder.<k>`` for each k in ``mol_nums``."""
    return [
        'encoder.encoder.' + str(mol_num) + '.' + suffix
        for mol_num in mol_nums
        for suffix in ('W_i.weight', 'W_h.weight', 'W_o.weight', 'W_o.bias')
    ]


def _frzn_ffn_param_names(num_layers: int) -> list:
    """Return weight/bias names for FFN entries ffn.1, ffn.4, ffn.7, ... (the first ``num_layers`` of them)."""
    return [
        'ffn.' + str(i * 3 + 1) + '.' + suffix
        for i in range(num_layers)
        for suffix in ('weight', 'bias')
    ]


def load_frzn_model(model: torch.nn,
                    path: str,
                    current_args: Namespace = None,
                    cuda: bool = None,
                    logger: logging.Logger = None) -> MoleculeModel:
    """
    Loads frozen parameters from a checkpoint into ``model``.

    Only encoder (MPNN) parameters and, when ``current_args.frzn_ffn_layers > 0``,
    the first ``frzn_ffn_layers`` FFN entries are copied from the checkpoint;
    every other parameter of ``model`` keeps its current value.

    :param model: The model whose parameters are (partially) overwritten.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Must be provided: its
        ``number_of_molecules``, ``checkpoint_frzn``, ``frzn_ffn_layers``,
        ``freeze_first_only`` and ``ffn_num_layers`` select what is copied.
    :param cuda: Unused by this function (kept for interface compatibility).
    :param logger: A logger. ``print`` is used when None.
    :return: ``model`` with the selected checkpoint weights loaded.
    :raises Exception: If the molecule counts of checkpoint and model are
        incompatible, or the FFN/freeze options cannot be honoured.
    """
    debug = logger.debug if logger is not None else print

    loaded_mpnn_model = torch.load(path,
                                   map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model['state_dict']
    loaded_args = loaded_mpnn_model['args']

    model_state_dict = model.state_dict()

    def copy_params(loaded_names, model_names=None):
        """Overwrite model params with checkpoint params, pairing the two name lists positionally."""
        nonlocal model_state_dict
        if model_names is None:
            model_names = loaded_names
        for loaded_name, model_name in zip(loaded_names, model_names):
            model_state_dict = overwrite_state_dict(
                loaded_name, model_name, loaded_state_dict, model_state_dict)

    n_loaded = loaded_args.number_of_molecules
    n_current = current_args.number_of_molecules
    has_frzn_checkpoint = current_args.checkpoint_frzn is not None
    freeze_ffn = current_args.frzn_ffn_layers > 0

    # Use `and`, not `&`: `a == 1 & b == 1` parses as the chained comparison
    # `a == (1 & b) == 1`, which is also true for a == 1 and any odd b > 1.
    if n_loaded == 1 and n_current == 1:
        encoder_param_names = _frzn_encoder_param_names([0])
        if has_frzn_checkpoint:
            # Freeze the MPNN
            copy_params(encoder_param_names)

        if freeze_ffn:
            # Freeze MPNN and FFN layers
            copy_params(encoder_param_names
                        + _frzn_ffn_param_names(current_args.frzn_ffn_layers))

        if current_args.freeze_first_only:
            debug(
                'WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)'
            )

    elif n_loaded == 1 and n_current > 1:
        if freeze_ffn:
            raise Exception(
                'Number of molecules in checkpoint_frzn must be equal to current model for ffn layers to be frozen'
            )

        if has_frzn_checkpoint:
            if current_args.freeze_first_only:
                # Only freeze the first MPNN
                copy_params(_frzn_encoder_param_names([0]))
            else:
                # Duplicate the checkpoint's single encoder into every encoder of the model.
                copy_params(_frzn_encoder_param_names([0]) * n_current,
                            _frzn_encoder_param_names(range(n_current)))

    elif n_loaded > 1:
        # Also covers n_current < n_loaded (e.g. a single-molecule model), which
        # previously fell through silently without loading anything.
        if n_loaded != n_current:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must match current model ({}) OR equal to 1.'
                .format(n_loaded, n_current))

        if current_args.freeze_first_only:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must be equal to 1 for freeze_first_only to be used.'
                .format(n_loaded))

        # Validate before touching the state dict, so we never look up FFN
        # parameter names that the model does not have.
        if current_args.frzn_ffn_layers >= current_args.ffn_num_layers:
            raise Exception(
                'Number of frozen FFN layers must be less than the number of FFN layers'
            )

        encoder_param_names = _frzn_encoder_param_names(range(n_current))
        if freeze_ffn:
            copy_params(encoder_param_names
                        + _frzn_ffn_param_names(current_args.frzn_ffn_layers))
        elif has_frzn_checkpoint:
            copy_params(encoder_param_names)

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model
예제 #17
0
 def __init__(self, module: torch.nn):
     self.hook = module.register_forward_hook(self.hook_fn)
예제 #18
0
def update_training_pool_ids(net: torch.nn, training_pool_ids_path: str,
                             all_training_data, device: str):
    """
    Grow the training pool with the most uncertain images of the remaining pool.

    Every id returned by ``get_pool_data(all_training_data)`` that is not yet
    listed in ``training_pool_ids_path`` is scored: ``net`` is run
    GAUSS_ITERATION times, the per-pixel standard deviation of the sigmoid of
    output channel 0 is computed, and ``get_segmentation_mask_uncertainty``
    reduces it to a score. The (up to) 100 highest-scoring ids are appended
    to the pool file with ``add_image_id_to_pool``.

    :param net: segmentation network; only channel 0 of its output is used.
    :param training_pool_ids_path: path to the json file which contains the
        image ids currently in the training pool.
    :param all_training_data: source passed to ``get_pool_data`` that lists
        all candidate image ids.
    :param device: device the images are moved to before inference.
    """
    batch_size = 1
    training_pool_ids = set(get_pool_data(training_pool_ids_path))
    all_ids = get_pool_data(all_training_data)
    active_pool = set(all_ids) - training_pool_ids
    dataset = RestrictedDataset(dir_img, dir_mask, list(active_pool))
    pool_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True)

    # (uncertainty score, image id) pairs, kept together so they cannot misalign.
    scored = []

    net.eval()
    # One progress bar (the loader used to be wrapped in a second, nested tqdm),
    # and one no_grad context for the whole scoring pass.
    with torch.no_grad(), tqdm(total=len(dataset), desc='STD calculating',
                               unit='batch', leave=False) as pbar:
        for batch in pool_loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)

            # NOTE(review): net.eval() is set above, so these passes only
            # differ if net stays stochastic in eval mode — confirm.
            y_pred_samples = torch.stack(
                [torch.sigmoid(net(imgs))[:, 0, :, :]
                 for _ in range(GAUSS_ITERATION)],
                dim=0)  # shape: (GAUSS_ITERATION, batch, H, W)
            y_pred_samples = y_pred_samples.type(torch.FloatTensor)
            std_y_pred = y_pred_samples.std(dim=0)  # shape: batch, H, W
            _std = get_segmentation_mask_uncertainty(std_y_pred)

            # Record each score exactly once, paired with its image id.
            scored.extend(zip(_std, batch['id']))
            pbar.update()

    scored.sort()  # order = ascending (ties broken by image id)
    print("length of std/imgs_id: ", len(scored), len(scored))
    # Works for an empty pool too (the old zip(*...) unpack raised ValueError).
    top_100img = [img_id for _, img_id in scored[-100:]]
    for img_id in top_100img:
        add_image_id_to_pool(img_id, training_pool_ids_path)
    print("Adding successfully!")