def contrastive_loss(encoding: torch.Tensor, positive_sample: torch.Tensor,
                     negative_samples: torch.Tensor) -> torch.Tensor:
    """
    Implements the contrastive loss defined in the CMC paper
    Defined for batch size N, embedding size D, and num_negative_samples K
    :param encoding: The encoded vector to contrast against samples (N x 1 x D)
    :param positive_sample: The positive sample to contrast against (N x 1 x D)
    :param negative_samples: The negative samples to contrast against (N x K x D)
    :return: The contrastive loss (softmax with correct label in positive location)
    """
    assert len(encoding.size()) == 3, "Expecting encoding shape: (N x 1 x D)"
    assert len(positive_sample.size()) == 3, "Expecting positive sample shape: (N x 1 x D)"
    assert len(negative_samples.size()) == 3, "Expecting negative sample shape: (N x K x D)"

    # Stack the positive sample on top of the negatives
    all_samples = torch.cat([positive_sample, negative_samples], dim=1)

    # Compute the "critic" scores from 3.2 Implementing the Critic
    scores = torch.bmm(all_samples, torch.transpose(encoding, 1, 2)).squeeze(-1)

    # Compute the contrastive loss: the positive sample sits at index 0
    targets = torch.zeros(scores.size()[0]).long().to(util.get_project_device())

    return F.cross_entropy(scores, targets)
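# --- Hedged usage sketch (not from the original repo) ---
# Demonstrates the tensor shapes contrastive_loss expects, with assumed sizes
# N=4 (batch), K=7 (negatives), D=128 (embedding). Assumes util.get_project_device()
# resolves to CPU here so the random inputs and the targets land on the same device.
def _contrastive_loss_example() -> torch.Tensor:
    encoding = torch.randn(4, 1, 128)          # (N x 1 x D)
    positive_sample = torch.randn(4, 1, 128)   # (N x 1 x D)
    negative_samples = torch.randn(4, 7, 128)  # (N x K x D)
    return contrastive_loss(encoding, positive_sample, negative_samples)  # scalar loss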
def __init__(self, model: nn.Module,
             loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             train_data: DataLoader,
             validation_data: Optional[DataLoader] = None,
             optimizer: Optional[torch.optim.Optimizer] = None,
             per_epoch_callbacks: List[Callable] = [],
             device: Optional[torch.device] = None,
             checkpoint_file: str = ""):
    """
    Trains a model given the data, loss function, and optionally an optimizer
    :param model: The model to train
    :param loss_function: The loss function to evaluate for training
    :param train_data: The dataloader to use for training the model
    :param validation_data: The dataloader used for model validation
    :param optimizer: [Optional] The optimizer used to update the weights of the model
    :param per_epoch_callbacks: A set of callbacks to be called every epoch. Each is passed the
        model, data loaders, optimizer, and wandb run. May be used for e.g. logging images
    :param device: The device ((C/G/T)PU) to train and validate on
    :param checkpoint_file: The location of the saved run to load the model and optimizer from
    """
    # Create an optimizer if none was passed
    if optimizer is None:
        self.optimizer = torch.optim.Adam(model.parameters(recurse=True), lr=0.001)
    else:
        self.optimizer = optimizer

    if device is None:
        self.device = util.get_project_device()
    else:
        self.device = device

    util.set_project_device(self.device)

    # Create a learning rate scheduler
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer)

    # Save the other class parameters
    self.model = model.to(self.device)
    self.loss_function = loss_function
    self.train_data = train_data
    self.validation_data = validation_data
    self.callbacks = per_epoch_callbacks

    # Load the model if a checkpoint is given
    if checkpoint_file != "":
        model_save_dict, optimizer_save_dict = model_io.load_model_checkpoint(checkpoint_file)
        self.model.load_state_dict(model_save_dict)

        if optimizer_save_dict is not None:
            self.optimizer.load_state_dict(optimizer_save_dict)

    # Initialize wandb and watch the model
    self.wandb_run = wandb.init(project='RosettaCV', entity='cal-launchpad')
    wandb.watch(self.model)
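# --- Hedged usage sketch (not from the original repo) ---
# Assumes the __init__ above belongs to a class named `Trainer`; the class name and
# the toy model are illustrative only. Constructing the object calls wandb.init,
# so wandb must already be configured for the project.
def _build_trainer(train_loader: DataLoader, val_loader: DataLoader) -> "Trainer":
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    return Trainer(
        model=toy_model,
        loss_function=F.cross_entropy,
        train_data=train_loader,
        validation_data=val_loader,
    )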
def get_cmc_loss_on_dataloader(model: nn.Module, dataloader: DataLoader,
                               loss_fn: Callable) -> torch.Tensor:
    """
    Gets the loss on a dataloader, with a different assumed loss signature from that in util.util
    :param model: The model to evaluate on
    :param dataloader: The dataloader to evaluate on
    :param loss_fn: The loss metric to use
    :return: The loss function applied to the dataloader
    """
    # The device the model is on
    device = util.get_project_device()

    total_loss = torch.FloatTensor([0]).to(device)

    # Get all the encodings
    for batch in dataloader:
        # Send to GPU if available
        for i, view in enumerate(batch):
            if isinstance(view, torch.Tensor):
                batch[i] = view.to(device)

        # Get the positive and negative samples from the batch
        encodings = model(batch, no_cache=True)
        enc, pos, neg = get_positive_and_negative_samples(encodings, model)

        total_loss += loss_fn(enc, pos, neg)

    return total_loss / len(dataloader)
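# --- Hedged usage sketch (not from the original repo) ---
# `wrapper_model` and `val_loader` stand in for a trained CMC wrapper model and a
# multiview validation DataLoader; evaluation is wrapped in torch.no_grad() since
# only the loss value is needed here.
def _evaluate_cmc(wrapper_model: nn.Module, val_loader: DataLoader) -> float:
    wrapper_model.eval()
    with torch.no_grad():
        loss = get_cmc_loss_on_dataloader(wrapper_model, val_loader, contrastive_loss)
    return loss.item()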
def decoding_loss(self, inputs: List[torch.Tensor], encodings: List[torch.Tensor]) -> torch.Tensor:
    """
    Samples decoding pathways and adds decoding loss to contrastive loss
    :param inputs: The inputs to the CMC encoders
    :param encodings: The encodings that were produced by the model
    :return: A loss on the decodings sampled
    """
    # Sample decoding pathways
    decodings = random.sample(self.encode_decode_pairs, k=self.num_decodings)

    # Perform all relevant decodings
    decoding_loss = torch.Tensor([0]).to(util.get_project_device())

    for encode_view_ind, decode_view_ind in decodings:
        decoded_view = self.model.views[decode_view_ind]
        # decoded_view_out = decoded_view.decode(encodings[encode_view_ind])
        decoded_view_out = decoded_view.decode(encodings[encode_view_ind], inputs[decode_view_ind])
        reconstruction_loss = decoded_view.reconstruction_loss(decoded_view_out, inputs[decode_view_ind])

        # Report this reconstruction loss
        self.wandb_run.log({
            f"{self.model.views[encode_view_ind].get_id()} -> "
            f"{self.model.views[decode_view_ind].get_id()} Reconstruction Loss": reconstruction_loss
        })

        decoding_loss += reconstruction_loss

    return decoding_loss
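# --- Hedged sketch (assumption, not from the repo) ---
# One plausible way to build the `encode_decode_pairs` that decoding_loss samples from:
# every ordered (encoder view index, decoder view index) pair whose target view actually
# has a decoder. Assumes `Tuple` is imported from typing alongside `List`.
def _build_encode_decode_pairs(views: List) -> List[Tuple[int, int]]:
    pairs = []
    for encode_ind in range(len(views)):
        for decode_ind in range(len(views)):
            if encode_ind != decode_ind and views[decode_ind].decoder is not None:
                pairs.append((encode_ind, decode_ind))
    return pairs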
def forward(self, x):
    # Tokenize the text
    tokenized = self.tokenizer(x, padding=True, return_tensors="pt")['input_ids']
    tokenized = tokenized.to(util.get_project_device())

    # Forward pass
    model_outputs = self.model(tokenized).last_hidden_state

    # Average over the sequence dimension
    model_outputs = torch.mean(model_outputs, dim=-2)
    model_outputs = model_outputs.view(model_outputs.shape[0], -1)

    return F.relu(self.linear(model_outputs))
def language_reconstruction_loss(predicted_logits: torch.Tensor,
                                 ground_truth_caption: Union[str, List]) -> torch.Tensor:
    """
    The language reconstruction loss, to compare a generated caption to a ground truth
    :param predicted_logits: The logits over the vocabulary predicted by the language decoder model
    :param ground_truth_caption: The actual caption for the image
    :return: The cross entropy loss of the caption and the predicted values
    """
    # Tokenize the caption into BERT's vocabulary
    if isinstance(ground_truth_caption, str):
        ground_truth_tok = tokenizer.encode(ground_truth_caption, padding=True, return_tensors="pt")
    else:
        ground_truth_tok = tokenizer.batch_encode_plus(ground_truth_caption, padding=True,
                                                       return_tensors="pt")['input_ids']

    # Reshape ground truth and predicted
    predicted_logits = predicted_logits.view(-1, predicted_logits.shape[-1])  # shape[-1] is the vocab size
    ground_truth_tok = ground_truth_tok.view(-1).to(util.get_project_device())

    # Cross entropy ignoring locations that are padded
    return F.cross_entropy(predicted_logits, ground_truth_tok, ignore_index=PAD_TOKEN)
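# --- Hedged usage sketch (not from the original repo) ---
# Shows the assumed shapes only: the language decoder emits logits of shape
# (batch, sequence_length, vocab_size), and sequence_length is assumed to match
# the padded length of the tokenized captions.
def _caption_loss_example(decoder_logits: torch.Tensor) -> torch.Tensor:
    captions = ["a dog runs on the beach", "two people ride bicycles"]
    # decoder_logits is assumed to be (2 x padded_caption_length x vocab_size)
    return language_reconstruction_loss(decoder_logits, captions)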
def forward(self, latent_encoding: torch.Tensor, decoder_inputs: torch.Tensor) -> torch.Tensor:
    # Change the dimension of the input to match cross-attention in BERT
    latent_encoding = F.relu(self.linear(latent_encoding))

    # Tokenize the inputs
    decoder_inputs = self.tokenizer(decoder_inputs, return_tensors="pt", padding=True).input_ids
    decoder_inputs = decoder_inputs.to(util.get_project_device())

    # Replicate the latent embedding to mimic an encoder output
    sequence_length = decoder_inputs.size()[1]
    latent_encoding = latent_encoding.unsqueeze(1)
    encoder_output = torch.tile(latent_encoding, (1, sequence_length, 1))

    # Generate logits for prediction
    return self.decoder_model(decoder_inputs, encoder_hidden_states=encoder_output).logits
def __init__(self, views: List[View], latent_dim: int, memory_bank_size: int = 200):
    super(WrapperModel, self).__init__()

    assert len(views) >= 2, "Must specify at least 2 views!"

    # Assign views and register submodules
    self.views = views
    self.view_encoders = nn.ModuleList([view.encoder for view in views])
    self.view_decoders = nn.ModuleList([view.decoder for view in views if view.decoder is not None])

    # Build the memory bank to sample from
    model_device = util.get_project_device()
    self.memory_bank = deque()
    self.memory_bank.extend([
        rand_vec.view(-1, latent_dim)
        for rand_vec in torch.randn((memory_bank_size, latent_dim)).to(model_device)
    ])
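# --- Hedged sketch (assumption, not from the repo) ---
# One way negatives could be drawn from the FIFO memory bank built above; assumes
# `random` and `deque` are imported. Each bank entry is a (1 x latent_dim) tensor,
# so K sampled entries stack into a (K x latent_dim) tensor.
def _sample_negatives(memory_bank: deque, num_negatives: int) -> torch.Tensor:
    chosen = random.sample(list(memory_bank), k=num_negatives)
    return torch.cat(chosen, dim=0)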
from torchvision import models, datasets
from torchvision.transforms import Compose, ToTensor, Resize
from data_loader.MultiviewDatasets import (MultiviewDataset, identity_view, get_coco_captions,
                                           get_grayscale_view, get_inverted_view,
                                           get_complementary_view, get_saturated_view)
from torch.utils.data import DataLoader

resnet_feature_size = 512

if __name__ == "__main__":
    from models.CMC.ResNetEncoder import ResNetEncoder
    from models.CMC.Decoder import Decoder
    from models.CMC.language_models import TextEncoder, TextDecoder

    # Ensure the aspect ratio is 0.75
    image_size = (480, 640)
    device = util.get_project_device()
    latent_dim = 512

    # Define encoders, decoders, and views
    image_encoder = ResNetEncoder(device, latent_dim=latent_dim)
    image_decoder = Decoder(*image_size)
    image_view = View(image_encoder, image_decoder, "Image", reconstruction_loss=l2_reconstruction_loss)

    # grayscale_encoder = ResNetEncoder(device, latent_dim=latent_dim)
    # grayscale_decoder = Decoder(*image_size)
    # grayscale_view = View(grayscale_encoder, grayscale_decoder, "Grayscale", reconstruction_loss=l2_reconstruction_loss)

    inverted_encoder = ResNetEncoder(device, latent_dim=latent_dim)