def train_model(model: nn, dataset: Dataset, num_epochs: int = 10, specialize: bool = False) -> nn:
    """Train ``model`` on every song of ``dataset`` for ``num_epochs`` epochs.

    Each song is unpacked as ``(input_tensors, tag, output_tensors)`` and is
    processed in windows of ``const.SEQ_LEN`` steps covering the indices
    ``1 .. len(input_tensors) - 2``. One optimizer step is taken per window,
    and the hidden state returned by the model is fed into the next window.

    Args:
        model: Network called as ``model(x_seq, hidden, tag)`` that returns
            ``(output, hidden)``.
        dataset: Sized iterable of songs.
        num_epochs: Number of passes over ``dataset``.
        specialize: When true the initial hidden state of every song is
            ``None``; otherwise it comes from ``model.init_hidden()``.

    Returns:
        The same ``model`` object, updated in place.
    """
    loss_function = const.LOSS_FUNCTION
    optimizer = const.OPTIMIZER(model.parameters(), lr=const.LEARNING_RATE)
    seq_len = const.SEQ_LEN

    for epoch in range(num_epochs):
        for i, song in enumerate(dataset):
            input_tensors, tag, output_tensors = song
            hidden = None if specialize else model.init_hidden()
            song_length = len(input_tensors)
            song_losses = []

            for t in range(1, song_length - 1, seq_len):
                roof = min(t + seq_len, song_length - 1)
                x_seq = input_tensors[t:roof]
                y_t = output_tensors[t:roof]

                model.zero_grad()
                output, hidden = model(x_seq, hidden, tag)
                loss = loss_function(output, y_t.view(roof - t, -1))
                song_losses.append(loss.item())
                # NOTE(review): `hidden` still carries the autograd history of
                # the previous windows, which is why retain_graph=True is
                # needed here. Confirm whether detaching it between windows
                # is what is actually intended.
                loss.backward(retain_graph=True)
                optimizer.step()

            print("Epoch", epoch + 1, "Song", i + 1, "/", len(dataset))
            if not song_losses:
                # A song with fewer than three steps yields no window at all;
                # averaging an empty list would raise ZeroDivisionError.
                print("Song too short to train on, skipped")
                continue
            print("Avg loss for this song", sum(song_losses) / len(song_losses))

    return model
def test_class_accuracy(network: torch.nn, loader: torch.utils.data.DataLoader, device: torch.device) -> float:
    """
    Test the class accuracy of a network on a dataset.

    The network is switched to eval mode and its ``classify`` attribute is
    set to ``True`` before the loader is consumed.

    :param network: network to test
    :param loader: loader to test; every batch is an ``(inputs, labels)`` pair
                   whose labels are class names
    :param device: device to use
    :return: result accuracy
    :raises ValueError: if a label is not one of the known class names
    """
    network.eval()
    network.classify = True
    accuracy = 0

    classes = utils.datasets.get_vggface2_classes("train")
    # Build the name -> id table once instead of scanning `classes` for every
    # label. `setdefault` keeps the first position, like `list.index` does.
    class_ids = {}
    for class_id, class_name in enumerate(classes):
        class_ids.setdefault(class_name, class_id)

    # Only the arg-max indices are used, so no autograd graph is needed.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            outputs = network(inputs)

            # Convert the class names to class ids.
            try:
                label_ids = [class_ids[label] for label in labels]
            except KeyError as err:
                raise ValueError(f"{err.args[0]!r} is not in list") from None
            labels = torch.LongTensor(label_ids).to(device)

            output_labels = torch.topk(outputs, 1).indices.view(-1)
            accuracy += torch.count_nonzero(output_labels == labels)

    # NOTE(review): the divisor is three times the dataset size; the reason
    # for the factor 3 is not visible in this function - confirm with caller.
    accuracy = accuracy / (len(loader.dataset) * 3)
    return accuracy
def train_fn(model: torch.nn, data_loader: DataLoader, optimizer: optim, device: torch.device, epoch: int):
    """Run one training pass of ``model`` over ``data_loader``.

    Every batch is ``(images, targets)``: the images are moved to ``device``
    and stacked into one tensor, each target dict has its values moved to
    ``device``, and the model is called as ``model(images, targets)``. The
    values of the returned loss dict are summed and back-propagated, followed
    by one optimizer step. A progress report is printed every tenth batch.

    Args:
        model: Network returning a dict that contains at least the keys
            ``loss_classifier``, ``loss_box_reg``, ``loss_objectness`` and
            ``loss_rpn_box_reg``.
        data_loader: Source of batches; ``data_loader.dataset`` must be sized.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batch is moved to.
        epoch: Zero-based epoch index, used only in the report.
    """
    model.train()
    started_at = datetime.datetime.now()
    seen_images = 0

    for step, (images, targets) in enumerate(data_loader, start=1):
        batch = torch.stack([picture.to(device) for picture in images])
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]
        seen_images += len(batch)

        optimizer.zero_grad()
        losses = model(batch, targets)
        total_loss = sum(losses.values())
        total_loss.backward()
        optimizer.step()

        if step % 10 != 0:
            continue

        dataset_size = len(data_loader.dataset)
        elapsed = datetime.datetime.now() - started_at
        print('-' * 50)
        print(
            f'Epoch {epoch+1}[{dataset_size:,}/{(seen_images/dataset_size)*100:.2f}%] '
            f'- Elapsed time: {elapsed}\n'
            f' - loss: classifier={losses["loss_classifier"]:.6f}, box_reg={losses["loss_box_reg"]:.6f}, '
            f'objectness={losses["loss_objectness"]:.6f}, rpn_box_reg={losses["loss_rpn_box_reg"]:.6f}'
        )
def save_checkpoint(self, epoch_score: float, model: torch.nn, model_path: str):
    """Persist ``model`` when ``epoch_score`` is a finite number.

    The model's ``state_dict`` is written to ``model_path`` only for finite
    scores. ``self.val_score`` is updated to ``epoch_score`` afterwards.

    Args:
        epoch_score: Validation score of the epoch that just finished.
        model: Model whose ``state_dict`` is saved.
        model_path: Destination passed to ``torch.save``.
    """
    # The previous membership test against a list containing np.nan could not
    # detect a computed NaN, because NaN never compares equal to itself.
    if np.isfinite(epoch_score):
        print(
            f'Validation score improved ({self.val_score:.6f} --> {epoch_score:.6f}). Saving model...\n'
        )
        torch.save(model.state_dict(), model_path)
    # NOTE(review): the collapsed source does not show whether this assignment
    # was inside the guard; it is kept unconditional here - confirm.
    self.val_score = epoch_score
def _get_cv_stats(self, model: torch.nn) -> Tuple[torch.tensor, ...]:
    """
    Collect CV data accuracy, loss, prediction distribution, and targets
    distribution.

    The model is put in eval mode while ``self.cvloader`` is consumed under
    ``torch.no_grad()`` and is switched to train mode before returning.

    Arguments:
        model {torch.nn} -- model to evaluate

    Returns:
        Tuple[torch.tensor * 4] -- accuracy, loss, prediction distribution,
        and targets distribution, all moved to ``CPU_DEVICE``. Accuracy and
        loss are divided by ``len(self.cvloader.dataset)``.
    """
    cv_acc = torch.tensor(0.0).to(self.device)
    cv_loss = torch.tensor(0.0).to(self.device)
    samples_no = float(len(self.cvloader.dataset))

    # Per-batch pieces are collected and concatenated once at the end;
    # concatenating inside the loop copies the whole history every batch.
    predicted_batches = []
    target_batches = []

    with torch.no_grad():
        model = model.eval()
        for inputs, targets in self.cvloader:
            batch_size = inputs.shape[0]  # last batch can be smaller
            targets = targets.to(self.device)
            outputs = model(inputs.to(self.device))
            predicted = outputs.argmax(1)

            predicted_batches.append(predicted.long())
            target_batches.append(targets.long())

            cv_acc += (predicted == targets).sum()
            # Weight by the batch size before the final division by the
            # dataset size.
            cv_loss += self.criterion(outputs, targets) * batch_size

    cv_acc = (cv_acc / samples_no).to(CPU_DEVICE)
    cv_loss = (cv_loss / samples_no).to(CPU_DEVICE)

    if predicted_batches:
        outputs_dist = torch.cat(predicted_batches).to(CPU_DEVICE)
        targets_dist = torch.cat(target_batches).to(CPU_DEVICE)
    else:
        # An empty loader used to crash on `None.to(...)`.
        outputs_dist = torch.empty(0, dtype=torch.long).to(CPU_DEVICE)
        targets_dist = torch.empty(0, dtype=torch.long).to(CPU_DEVICE)

    model = model.train()
    return cv_acc, cv_loss, outputs_dist, targets_dist
def valid_fn(
        model: torch.nn,
        data_loader: DataLoader,
        device: torch.device,
        class_names: Dict[int, str],
        valid: bool = True):
    """Evaluate ``model`` on ``data_loader`` and save the detections as XML.

    For every image one ``<image name=...>`` element is created, with one
    ``<predict>`` child per detection holding the class name, the score and
    the integer box corners. The tree is written to
    ``./output/validation.xml`` when ``valid`` is true, otherwise to
    ``./output/prediction.xml``.

    Image names are read from ``data_loader.dataset.images`` at position
    ``batch_index * batch_size + index_in_batch``.
    NOTE(review): that lookup presumably requires a non-shuffling loader -
    confirm with the caller.

    Args:
        model: Detector returning, per image, a dict with ``boxes``,
            ``labels`` and ``scores``.
        data_loader: Loader yielding ``(images, _)`` batches.
        device: Device the images are moved to.
        class_names: Mapping from label id to class name.
        valid: Selects the output file name.
    """
    xml_root = ET.Element('predictions')
    batch_size: int = data_loader.batch_size

    model.eval()
    with torch.set_grad_enabled(False):
        for i, (images, _) in tqdm(enumerate(data_loader)):
            images = [image.to(device) for image in images]
            outputs = model(images)

            for j, output in enumerate(outputs):
                image_name: str = data_loader.dataset.images[i * batch_size + j]
                xml_image = ET.SubElement(xml_root, 'image', {'name': image_name})

                boxes = output['boxes'].detach().cpu().numpy()
                labels = output['labels'].detach().cpu().numpy()
                scores = output['scores'].detach().cpu().numpy()

                for box, label, score in zip(boxes, labels, scores):
                    attribs = {
                        'class_name': class_names[label],
                        'score': str(float(score)),
                        'x1': str(int(box[0])),
                        'y1': str(int(box[1])),
                        'x2': str(int(box[2])),
                        'y2': str(int(box[3]))
                    }
                    ET.SubElement(xml_image, 'predict', attribs)

    indent(xml_root)
    tree = ET.ElementTree(xml_root)

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('./output/', exist_ok=True)

    if valid:
        tree.write('./output/validation.xml')
    else:
        tree.write('./output/prediction.xml')

    print('Save predicted labels.xml...\n')
def init_weights(net: torch.nn, init_type: str = "normal", init_gain: float = 0.02) -> None:
    """Initialize network weights.

    Ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

    Modules whose class name contains ``Conv`` or ``Linear`` get their weight
    initialised with the selected method and their bias (if any) set to zero.
    Modules whose class name contains ``BatchNorm2d`` get weight ~ N(1,
    init_gain) and bias zero when those parameters exist.

    Parameters:
    -----------
    net: network to be initialized
    init_type: the name of an initialization method
               `normal` | `xavier` | `kaiming` | `orthogonal`
    init_gain: scaling factor for normal, xavier and orthogonal.

    Raises:
    -------
    NotImplementedError: if ``init_type`` is not one of the names above and
        the network contains a Conv/Linear module.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and
    kaiming might work better for some applications. Feel free to try yourself.
    """

    def init_func(m):
        """Define the initialization function."""
        classname = m.__class__.__name__
        is_conv_or_linear = "Conv" in classname or "Linear" in classname

        if hasattr(m, "weight") and is_conv_or_linear:
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type)
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # BatchNorm Layer's weight is not a matrix;
            # only normal distribution applies.
            # With affine=False both parameters are None, so guard each one.
            if getattr(m, "weight", None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, "bias", None) is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
def test_fn(
        model: torch.nn,
        data_loader: DataLoader,
        class_nums: Dict,
        device: torch.device):
    """Run ``model`` on ``data_loader`` and write the masks as polygons to XML.

    For every image one ``<image name=...>`` element is created, with one
    ``<predict>`` child per detection holding the class name, the score and
    the first polygon of the thresholded mask, encoded as ``x,y;x,y;...``.
    The tree is written to ``./output/prediction.xml``.

    Image names are derived from ``data_loader.dataset.image_ids`` and looked
    up at position ``batch_index * batch_size + index_in_batch``.
    NOTE(review): that lookup presumably requires a non-shuffling loader -
    confirm with the caller.

    Args:
        model: Network returning, per image, a dict with ``masks``,
            ``labels`` and ``scores``.
        data_loader: Loader yielding ``(images, _)`` batches.
        class_nums: Mapping from label id to class name.
        device: Device the images are moved to.
    """
    image_ids = [image_id.split('.')[1][-17:] for image_id in data_loader.dataset.image_ids]
    xml_root = ET.Element('predictions')

    model.eval()
    batch_size = data_loader.batch_size

    with torch.no_grad():
        for i, (images, _) in tqdm(enumerate(data_loader)):
            images = [image.to(device) for image in images]
            outputs = model(images)

            for j, output in enumerate(outputs):
                image_name = image_ids[i * batch_size + j]
                xml_image = ET.SubElement(xml_root, 'image', {'name': image_name})

                masks = output['masks'].detach().cpu().numpy()
                labels = output['labels'].detach().cpu().numpy()
                scores = output['scores'].detach().cpu().numpy()

                for mask, label, score in zip(masks, labels, scores):
                    # The comparison already yields the boolean mask.
                    mask_bin = mask[0] > 0.1
                    points = Mask(mask_bin).polygons().points
                    # NOTE(review): assumes `points` is a sequence of
                    # polygons; a mask with nothing above the threshold would
                    # make `points[0]` fail, so such detections are skipped.
                    if len(points) == 0:
                        continue
                    point = ''.join(f'{p[0]},{p[1]};' for p in points[0])

                    attribs = {
                        'class_name': class_nums[label],
                        'score': str(float(score)),
                        'polygon': point,
                    }
                    ET.SubElement(xml_image, 'predict', attribs)

    indent(xml_root)
    tree = ET.ElementTree(xml_root)

    if not os.path.exists('./output/'):
        print('Not exists ./output/ making an weight folder...')
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('./output/', exist_ok=True)

    tree.write('./output/prediction.xml')
    print('Save predicted labels.xml...\n')
def __init__(
        self,
        args,
        encoder: torch.nn,  # encoder for online network / target network is copy of online network
        predictor: torch.nn,  # predictor network comes after online network
        optimizer: torch.optim,  # cosine annealing learning rate after warm-up, wo/ restart
        **params):
    """Set up the online, target and predictor networks.

    Args:
        args: Configuration object providing ``device``, ``max_epoch``,
            ``batch_size``, ``num_workers`` and ``resume``.
        encoder: Encoder used as the online network; the target network is
            an independent deep copy of it.
        predictor: Predictor network, moved to ``args.device``.
        optimizer: Optimizer stored on the instance.
        **params: Accepted but not used in this constructor.
    """
    import copy  # local import: only needed to clone the encoder

    super(BYOL, self).__init__()
    self.args = args

    self.online_net = encoder.to(args.device)
    # `Module.to` returns the very same module, so assigning `encoder.to(...)`
    # a second time made the target an alias of the online network. The
    # target must be a separate copy.
    self.target_net = copy.deepcopy(self.online_net)
    self.predictor = predictor.to(args.device)
    self.optimizer = optimizer

    self.max_epoch = args.max_epoch
    self.batch_size = args.batch_size
    self.num_workers = args.num_workers  # num_workers for data loader
    self.device = args.device  # set cuda device

    # `path_summary` is not defined in this block; it is expected to exist
    # at module level.
    self.writer = SummaryWriter(log_dir=path_summary)
    self.resume = args.resume
def predict(model: torch.nn, final_activ: torch.nn.functional, dl, out_type, with_preds=False):
    """Run ``model`` over every batch of ``dl`` and collect the outputs.

    The first element of each batch is fed to the model, and the result is
    passed through ``apply_final_activ``. Results are written into tensors
    pre-allocated on ``DEFAULT_DEVICE`` with one row per dataset item, at the
    offset ``batch_index * dl.batch_size``.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        final_activ: Activation forwarded to ``apply_final_activ``.
        dl: Data loader exposing ``dataset`` and ``batch_size``.
        out_type: Output type forwarded to ``apply_final_activ``.
        with_preds: When true, ``apply_final_activ`` is expected to return an
            ``(outputs, preds)`` pair and the raw predictions are collected
            as well.

    Returns:
        ``outputs`` alone, or the tuple ``(outputs, preds)`` when
        ``with_preds`` is true.
    """
    model.eval()

    # NOTE(review): no `torch.no_grad()` is used, so the collected tensors
    # keep their autograd history - confirm whether that is intended.
    for n, batch in enumerate(dl):
        batch = [nb.to(device=DEFAULT_DEVICE) for nb in batch]
        preds_batch = model(batch[0])
        outputs_batch = apply_final_activ(
            input=preds_batch,
            out_type=out_type,
            final_activ=final_activ,
            with_preds=with_preds,
        )
        if with_preds:
            outputs_batch, preds_batch = outputs_batch

        if not n:
            # Allocate the result buffers once the first batch reveals the
            # dtypes (and the prediction width).
            if with_preds:
                preds = torch.empty(
                    len(dl.dataset), preds_batch.shape[1],
                    device=DEFAULT_DEVICE, dtype=preds_batch.dtype)
            outputs = torch.empty(
                len(dl.dataset),
                device=DEFAULT_DEVICE, dtype=outputs_batch.dtype)

        start = n * dl.batch_size
        stop = (n + 1) * dl.batch_size
        if with_preds:
            preds[start:stop, :] = preds_batch
        outputs[start:stop] = outputs_batch

    if with_preds:
        return outputs, preds
    # The original returned the undefined name `output` here (NameError).
    return outputs
def __init__(
        self,
        model: torch.nn.Module,
        loss_function: torch.nn,
        optimizer: torch.optim,
        epochs: int,
        model_info: list,
        save_period: int,
        savedir: str,
        lr_scheduler: torch.optim.lr_scheduler = None,
        device: str = None,
):
    """
    Args:
        model (torch.nn.Module): The model to be trained
        loss_function (MultiLoss): The loss function or loss function class;
            it is moved to the training device
        optimizer (torch.optim): torch.optim, i.e., the optimizer class
        epochs (int): value stored as ``self.epochs``
        model_info (list): value stored as ``self.model_info``
        save_period (int): value stored as ``self.save_period``
        savedir (str): base directory; checkpoints go to a sub-directory
            named after today's date (``YYYY-MM-DD``)
        lr_scheduler (torch.optim.lr_scheduler): pytorch lr_scheduler for
            manipulating the learning rate
        device (str): string of the device to be trained on, e.g., "cuda:0".
            When omitted, CUDA is used if available, otherwise the CPU.
    """
    # `torch.device(None)` raises, so the documented default needs a fallback.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Model to device
    self.device = torch.device(device)
    self.model = model.to(self.device)

    self.lr_scheduler = lr_scheduler
    self.loss_function = loss_function.to(self.device)
    self.optimizer = optimizer

    self.epochs = epochs
    self.model_info = model_info
    self.save_period = save_period
    self.start_epoch = 1

    self.checkpoint_dir = Path(savedir) / datetime.today().strftime('%Y-%m-%d')
    # Minimum validation loss achieved, starting with the largest possible number
    self.min_validation_loss = sys.float_info.max
def update_training_pool_ids_2(net: torch.nn, training_pool_ids_path: str, all_training_data, device: str,
                               acquisition_func: str = "cfe"):
    """Score the images outside the training pool and add the top 100 to it.

    Args:
        net: Network handed to the acquisition function; switched to eval mode.
        training_pool_ids_path: the path to json file which contains images id
            in training pool.
        all_training_data: Argument for ``get_pool_data`` that yields the ids
            of all training images.
        device: Device the images are moved to.
        acquisition_func: string name of acquisition function:
            ``"cfe"`` (category_first_entropy), ``"mfe"`` (mean_first_entropy)
            or ``"mi"`` (mutual_information).

    Raises:
        ValueError: if ``acquisition_func`` is not one of the names above.

    This function will use an acquisition function to collect new 100 imgs
    into training pool each phase, i.e. the json file grows by up to 100
    image ids per call. The images with the highest values are chosen.
    """
    criterion_names = {
        "cfe": "category_first_entropy",
        "mfe": "mean_first_entropy",
        "mi": "mutual_information",
    }
    if acquisition_func not in criterion_names:
        # Previously this only printed a message and crashed later when the
        # missing criterion (None) was called.
        raise ValueError(
            f"Error choosing acquisition function: {acquisition_func!r} "
            f"(expected one of {sorted(criterion_names)})")
    evaluation_criteria = getattr(acquisition_function, criterion_names[acquisition_func])

    batch_size = 1
    training_pool_data = get_pool_data(training_pool_ids_path)
    all_training_data = get_pool_data(all_training_data)
    active_pool = set(all_training_data) - set(training_pool_data)

    dataset = RestrictedDataset(dir_img, dir_mask, list(active_pool))
    pool_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

    value = []
    imgs_id = []
    net.eval()

    n_pool = len(dataset)
    # One progress bar is enough; the loader is no longer wrapped in a second.
    with tqdm(total=n_pool, desc='STD calculating', unit='batch', leave=False) as pbar:
        for batch in pool_loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)

            _value = evaluation_criteria(GAUSS_ITERATION, net, imgs)
            _imgs_id = batch['id']
            # Batches for which the criterion returns nothing are ignored.
            if len(_value) > 0:
                value.extend(_value)
                imgs_id.extend(_imgs_id)
            pbar.update()

    if not value:
        # Nothing left to score; unpacking an empty zip would raise.
        print("length of value/imgs_id: ", 0, 0)
        return

    value, imgs_id = zip(*sorted(zip(value, imgs_id)))  # order = ascending
    print("length of value/imgs_id: ", len(value), len(imgs_id))
    top_100img = imgs_id[-100:]  # the higher

    for i in top_100img:
        add_image_id_to_pool(i, training_pool_ids_path)

    print("Adding successfully!")
def _frzn_encoder_param_names(mol_nums):
    """Return the flat list of MPNN encoder parameter names for the given encoder indices."""
    suffixes = ("W_i.weight", "W_h.weight", "W_o.weight", "W_o.bias")
    return [
        f"encoder.encoder.{mol_num}.{suffix}"
        for mol_num in mol_nums
        for suffix in suffixes
    ]


def _frzn_ffn_param_names(num_layers):
    """Return the weight/bias parameter names of the first ``num_layers`` FFN layers."""
    return [
        f"ffn.{i * 3 + 1}.{kind}"
        for i in range(num_layers)
        for kind in ("weight", "bias")
    ]


def load_frzn_model(
    model: torch.nn,
    path: str,
    current_args: Namespace = None,
    cuda: bool = None,
    logger: logging.Logger = None,
) -> MoleculeModel:
    """
    Loads a model checkpoint and copies its frozen parameters into ``model``.

    Which parameters are copied depends on the number of molecules of the
    checkpoint and of the current run, and on ``checkpoint_frzn``,
    ``freeze_first_only`` and ``frzn_ffn_layers`` of ``current_args``.

    :param model: Model whose state dict receives the checkpoint parameters.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Read unconditionally, so it
                         must not be ``None`` despite the default.
    :param cuda: Whether to move model to cuda. Not used in this function.
    :param logger: A logger; ``print`` is used when it is ``None``.
    :return: ``model`` with the updated state dict loaded.
    :raises ValueError: on incompatible molecule counts or flags, or when
                        ``frzn_ffn_layers >= ffn_num_layers``.
    """
    debug = logger.debug if logger is not None else print

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    loaded_mpnn_model = torch.load(path, map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model["state_dict"]
    loaded_args = loaded_mpnn_model["args"]

    model_state_dict = model.state_dict()

    loaded_mols = loaded_args.number_of_molecules
    current_mols = current_args.number_of_molecules
    frzn_ffn_layers = current_args.frzn_ffn_layers

    if loaded_mols == 1 and current_mols == 1:
        encoder_param_names = _frzn_encoder_param_names([0])
        if current_args.checkpoint_frzn is not None:
            # Freeze the MPNN
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if frzn_ffn_layers > 0:
            # Freeze MPNN and FFN layers
            ffn_param_names = _frzn_ffn_param_names(frzn_ffn_layers)
            for param_name in encoder_param_names + ffn_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if current_args.freeze_first_only:
            debug(
                "WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)"
            )

    elif loaded_mols == 1 and current_mols > 1:
        # TODO(degraff): these two `if`-blocks can be condensed into one
        if (current_args.checkpoint_frzn is not None
                and current_args.freeze_first_only
                and frzn_ffn_layers <= 0):
            # Only freeze first MPNN
            for param_name in _frzn_encoder_param_names([0]):
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if (current_args.checkpoint_frzn is not None
                and not current_args.freeze_first_only
                and frzn_ffn_layers <= 0):
            # Duplicate encoder from frozen checkpoint and overwrite all encoders.
            # Both lists must be flat lists of strings; the model-side list
            # used to be a list of 4-tuples, so the pairing never matched.
            loaded_encoder_param_names = _frzn_encoder_param_names([0] * current_mols)
            model_encoder_param_names = _frzn_encoder_param_names(range(current_mols))
            for loaded_param_name, model_param_name in zip(
                    loaded_encoder_param_names, model_encoder_param_names):
                model_state_dict = overwrite_state_dict(
                    loaded_param_name, model_param_name, loaded_state_dict, model_state_dict)

        if frzn_ffn_layers > 0:
            raise ValueError(
                f"Number of molecules from checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must equal current number of molecules ({current_args.number_of_molecules})!"
            )

    elif loaded_mols > 1 and current_mols > 1:
        if loaded_mols != current_mols:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                f"must either match current model ({current_args.number_of_molecules}) or equal 1."
            )

        if current_args.freeze_first_only:
            raise ValueError(
                f"Number of molecules in checkpoint_frzn ({loaded_args.number_of_molecules}) "
                "must be equal to 1 for freeze_first_only to be used!")

        encoder_param_names = _frzn_encoder_param_names(range(current_mols))

        if current_args.checkpoint_frzn is not None and not frzn_ffn_layers > 0:
            for param_name in encoder_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

        if frzn_ffn_layers > 0:
            # FFN layer i sits at index i*3+1, as in the single-molecule
            # branch; this branch previously computed i+3+1.
            ffn_param_names = _frzn_ffn_param_names(frzn_ffn_layers)
            for param_name in encoder_param_names + ffn_param_names:
                model_state_dict = overwrite_state_dict(
                    param_name, param_name, loaded_state_dict, model_state_dict)

    if frzn_ffn_layers >= current_args.ffn_num_layers:
        raise ValueError(
            f"Number of frozen FFN layers ({current_args.frzn_ffn_layers}) "
            f"must be less than the number of FFN layers ({current_args.ffn_num_layers})!"
        )

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model
def generate_sample_radial(
    model: torch.nn,
    features: int = 10,
    targets: int = 1,
    iterations: int = 10,
    start_image: Tensor = None,
    image_size: int = 64,
    return_intermediate: bool = True,
    device: str = "cpu",
    rotations: int = 2,
    batch_size: int = 256,
    all_random: bool = True,
) -> List[Tensor]:
    """
    Repeatedly apply the network to radial samples of an image and collect
    the image produced by every application.

    The working image starts as a deep copy of ``start_image`` or, when none
    is given, as uniform noise in [-1, 1]. When ``all_random`` is ``True`` the
    working image is reset that same way before every iteration; otherwise
    each iteration continues from the previous network output.

    Args :
        model : The nn model (switched to eval mode)
        features : Number of feature pixels passed to the radial sampler
        targets : Not used in this function
        iterations : Number of times to apply the network to the image
        start_image : A normalized image in range [-1, 1] used as the
        initial value
        image_size : The width of the image (assumed square)
        return_intermediate : Not used in this function
        device : The device to perform operations on
        rotations : The number of rotations used by the network
        batch_size : Not used in this function
        all_random : Reset the working image before every iteration
    Returns :
        A list with one image of shape (3, image_size, image_size) per
        iteration
    """

    def _noise_image():
        # Uniform noise rescaled from [0, 1) to [-1, 1).
        return torch.rand([3, image_size, image_size], device=device) * 2 - 1

    def _initial_image():
        if start_image is None:
            return _noise_image()
        return copy.deepcopy(start_image)

    model.eval()
    image = _initial_image()

    mesh = positions_from_mesh(
        width=image_size,
        height=image_size,
        device=device,
        rotations=rotations,
        normalize=True,
    )
    # These need to be normalized otherwise the max is the width of the grid
    stripe_list = torch.cat([val.unsqueeze(0) for val in mesh])

    indices, target_linear_indices = indices_from_grid(image_size, device=device)

    frames = []
    for count in range(iterations):
        logger.info(f"Generating for count {count}")

        if all_random is True:
            image = _initial_image()

        features_tensor, _ = random_radial_samples_from_image(
            img=image,
            stripe_list=stripe_list,
            image_size=image_size,
            feature_pixels=features,
            indices=indices,
            target_linear_indices=target_linear_indices,
            device=device,
        )

        prediction = model(features_tensor.flatten(1))
        # One RGB triple per pixel -> channel-first image.
        image = prediction.reshape(image_size, image_size, 3).permute(2, 0, 1)
        frames.append(image)

    return frames
def generate_sample(
    model: torch.nn,
    features: int = 10,
    targets: int = 1,
    iterations: int = 10,
    image: Tensor = None,
    image_size: int = 64,
    return_intermediate: bool = True,
    device: str = "cpu",
    rotations: int = 2,
    batch_size: int = 256,
    all_random: bool = True,
) -> List[Tensor]:
    """
    Create a sample either generated on top of a starting image or from
    a random image. Return one image for each time the network is applied.

    When ``all_random`` is ``True`` (the default) the working image is
    replaced by fresh noise before every iteration, so a supplied ``image``
    only matters when ``all_random`` is false.

    Args :
        model : The nn model (switched to eval mode)
        features : The number of points used for interpolation
        targets : Not used in this function
        iterations : Number of times to apply the network to the image
        image : An unnormalized image (values scaled by 255) to use as the
        initial value; it is rescaled to [-1, 1] and moved to ``device``
        image_size : The width of the image (assumed square)
        return_intermediate : Not used in this function
        device : The device to perform operations on
        rotations : The number of rotations used by the network
        batch_size : Not used in this function
        all_random : Draw a new random image before every iteration
    Returns :
        A list of images for each time the network is applied
    """
    num_pixels = image_size * image_size
    model.eval()

    if image is None:
        image = torch.rand([3, image_size, image_size], device=device) * 2 - 1
    else:
        # The position channels below live on `device`; concatenating an
        # image from another device with them would fail.
        image = ((image / 255) * 2 - 1).to(device)

    stripe_list = positions_from_mesh(
        width=image_size,
        height=image_size,
        device=device,
        rotations=rotations,
        normalize=True,
    )
    # These need to be normalized otherwise the max is the width of the grid
    new_vals = torch.cat([val.unsqueeze(0) for val in stripe_list])

    # Every pixel is a target once; the index vector does not change between
    # iterations, so it is built a single time.
    target_indices = torch.arange(start=0, end=num_pixels, device=device)

    result_list = []
    for count in range(iterations):
        logger.info(f"Generating for count {count}")

        if all_random is True:
            image = torch.rand([3, image_size, image_size], device=device) * 2 - 1

        # Stack colour and position channels, then lay them out per pixel.
        full_features = torch.cat([image, new_vals])
        channels = full_features.shape[0]
        full_features = full_features.reshape(channels, -1).permute(1, 0)

        # `features` random source pixels for every target pixel.
        feature_indices = torch.remainder(
            torch.randperm(num_pixels * features, device=device), num_pixels
        )

        features_tensor = full_features[feature_indices].reshape(-1, features, channels)
        targets_tensor = full_features[target_indices].reshape(-1, 1, channels)

        # Distances are measured relative to target so remove that component
        features_tensor[:, :, 3:] = features_tensor[:, :, 3:] - targets_tensor[:, :, 3:]

        result = model(features_tensor.flatten(1))
        image = result.reshape(image_size, image_size, 3).permute(2, 0, 1)
        result_list.append(image)

    return result_list
def load_frzn_model(model: torch.nn,
                    path: str,
                    current_args: Namespace = None,
                    cuda: bool = None,
                    logger: logging.Logger = None) -> MoleculeModel:
    """
    Loads a model checkpoint and copies its frozen parameters into ``model``.

    Which parameters are copied depends on the number of molecules of the
    checkpoint and of the current run, and on ``checkpoint_frzn``,
    ``freeze_first_only`` and ``frzn_ffn_layers`` of ``current_args``.

    :param model: Model whose state dict receives the checkpoint parameters.
    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. Read unconditionally, so it
                         must not be ``None`` despite the default.
    :param cuda: Whether to move model to cuda. Not used in this function.
    :param logger: A logger; ``print`` is used when it is ``None``.
    :return: ``model`` with the updated state dict loaded.
    :raises Exception: on incompatible molecule counts or flags, or when
                       ``frzn_ffn_layers >= ffn_num_layers``.
    """
    debug = logger.debug if logger is not None else print

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    loaded_mpnn_model = torch.load(path, map_location=lambda storage, loc: storage)
    loaded_state_dict = loaded_mpnn_model['state_dict']
    loaded_args = loaded_mpnn_model['args']

    model_state_dict = model.state_dict()

    def encoder_names(mol_nums):
        """Flat list of encoder parameter names for the given encoder indices."""
        suffixes = ('W_i.weight', 'W_h.weight', 'W_o.weight', 'W_o.bias')
        return [f'encoder.encoder.{mol_num}.{suffix}'
                for mol_num in mol_nums for suffix in suffixes]

    def ffn_names(num_layers):
        """Weight/bias names of the first ``num_layers`` FFN layers."""
        return [f'ffn.{i * 3 + 1}.{kind}'
                for i in range(num_layers) for kind in ('weight', 'bias')]

    def copy_params(name_pairs):
        """Copy each (checkpoint name, model name) pair into the model state dict."""
        nonlocal model_state_dict
        for loaded_param_name, model_param_name in name_pairs:
            model_state_dict = overwrite_state_dict(
                loaded_param_name, model_param_name, loaded_state_dict, model_state_dict)

    loaded_mols = loaded_args.number_of_molecules
    current_mols = current_args.number_of_molecules
    frzn_ffn_layers = current_args.frzn_ffn_layers

    # The original condition `a == 1 & b == 1` was parsed as the chained
    # comparison `a == (1 & b) == 1`, which is also true for any odd `b`.
    if loaded_mols == 1 and current_mols == 1:
        single_encoder = encoder_names([0])
        if current_args.checkpoint_frzn is not None:
            # Freeze the MPNN
            copy_params(zip(single_encoder, single_encoder))

        if frzn_ffn_layers > 0:
            # Freeze MPNN and FFN layers
            names = single_encoder + ffn_names(frzn_ffn_layers)
            copy_params(zip(names, names))

        if current_args.freeze_first_only:
            debug(
                f'WARNING: --freeze_first_only flag cannot be used with number_of_molecules=1 (flag is ignored)'
            )

    elif loaded_mols == 1 and current_mols > 1:
        without_ffn = not frzn_ffn_layers > 0

        if (current_args.checkpoint_frzn is not None
                and current_args.freeze_first_only
                and without_ffn):
            # Only freeze first MPNN
            first_encoder = encoder_names([0])
            copy_params(zip(first_encoder, first_encoder))

        if (current_args.checkpoint_frzn is not None
                and not current_args.freeze_first_only
                and without_ffn):
            # Duplicate encoder from frozen checkpoint and overwrite all encoders
            copy_params(zip(encoder_names([0] * current_mols),
                            encoder_names(range(current_mols))))

        if frzn_ffn_layers > 0:
            # Duplicate encoder from frozen checkpoint and overwrite all encoders + FFN layers
            raise Exception(
                'Number of molecules in checkpoint_frzn must be equal to current model for ffn layers to be frozen'
            )

    elif loaded_mols > 1 and current_mols > 1:
        if loaded_mols != current_mols:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must match current model ({}) OR equal to 1.'
                .format(loaded_args.number_of_molecules,
                        current_args.number_of_molecules))

        if current_args.freeze_first_only:
            raise Exception(
                'Number of molecules in checkpoint_frzn ({}) must be equal to 1 for freeze_first_only to be used.'
                .format(loaded_args.number_of_molecules))

        all_encoders = encoder_names(range(current_mols))

        if current_args.checkpoint_frzn is not None and not frzn_ffn_layers > 0:
            copy_params(zip(all_encoders, all_encoders))

        if frzn_ffn_layers > 0:
            names = all_encoders + ffn_names(frzn_ffn_layers)
            copy_params(zip(names, names))

    if frzn_ffn_layers >= current_args.ffn_num_layers:
        raise Exception(
            'Number of frozen FFN layers must be less than the number of FFN layers'
        )

    # Load pretrained weights
    model.load_state_dict(model_state_dict)

    return model
def __init__(self, module: torch.nn):
    """Register this object's ``hook_fn`` as a forward hook on ``module``.

    The handle returned by ``register_forward_hook`` is stored on
    ``self.hook``.
    """
    callback = self.hook_fn
    handle = module.register_forward_hook(callback)
    self.hook = handle
def update_training_pool_ids(net: torch.nn, training_pool_ids_path: str, all_training_data, device: str):
    """Score the images outside the training pool and add the top 100 to it.

    Every candidate image is passed through ``net`` ``GAUSS_ITERATION``
    times; the per-pixel standard deviation of the sigmoid of the first
    output channel is reduced by ``get_segmentation_mask_uncertainty`` and
    the images with the highest values are appended to the pool file.

    Args:
        net: Network to query; switched to eval mode.
        training_pool_ids_path: the path to json file which contains images id
            in training pool.
        all_training_data: Argument for ``get_pool_data`` that yields the ids
            of all training images.
        device: Device the images are moved to.

    This function will use an acquisition function to collect new 100 imgs
    into training pool each phase, i.e. the json file grows by up to 100
    image ids per call.
    """
    batch_size = 1
    training_pool_data = get_pool_data(training_pool_ids_path)
    all_training_data = get_pool_data(all_training_data)
    active_pool = set(all_training_data) - set(training_pool_data)

    dataset = RestrictedDataset(dir_img, dir_mask, list(active_pool))
    pool_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

    std = []
    imgs_id = []
    # NOTE(review): in eval mode standard dropout layers are inactive; unless
    # the network keeps its stochastic layers enabled, all passes below are
    # identical and the standard deviation is zero - confirm.
    net.eval()

    n_pool = len(dataset)
    # One progress bar is enough; the loader is no longer wrapped in a second.
    with tqdm(total=n_pool, desc='STD calculating', unit='batch', leave=False) as pbar:
        for batch in pool_loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)

            y_pred_samples = []
            with torch.no_grad():
                for _ in range(GAUSS_ITERATION):
                    y_pred = torch.sigmoid(net(imgs))
                    # Keep only the first channel.
                    y_pred_samples.append(y_pred[:, 0, :, :])

            # y_pred_samples's shape: (inx, bat, H, W)
            y_pred_samples = torch.stack(y_pred_samples, dim=0)
            y_pred_samples = y_pred_samples.type(torch.FloatTensor)
            std_y_pred = y_pred_samples.std(dim=0)  # shape: batch, H, W

            _std = get_segmentation_mask_uncertainty(std_y_pred)
            _imgs_id = batch['id']
            # Batches for which no uncertainty is returned are ignored.
            if len(_std) > 0:
                std.extend(_std)
                imgs_id.extend(_imgs_id)
            pbar.update()

    if not std:
        # Nothing left to score; unpacking an empty zip would raise.
        print("length of std/imgs_id: ", 0, 0)
        return

    std, imgs_id = zip(*sorted(zip(std, imgs_id)))  # order = ascending
    print("length of std/imgs_id: ", len(std), len(imgs_id))
    top_100img = imgs_id[-100:]

    for i in top_100img:
        add_image_id_to_pool(i, training_pool_ids_path)

    print("Adding successfully!")