class QMPN(nn.Module):
    def __init__(self, opt):
        super(QMPN, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)
        if type(opt.embed_dims) == int:
            self.embed_dims = [opt.embed_dims]
        else:
            self.embed_dims = [int(s) for s in opt.embed_dims.split(',')]
        self.speaker_num = opt.speaker_num
        self.dataset_name = opt.dataset_name

        # MELD data
        # The one-hot vectors are not the global user ID
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.n_classes = opt.output_dim
        self.projections = nn.ModuleList(
            [nn.Linear(dim, embed_dim)
             for dim, embed_dim in zip(self.input_dims, self.embed_dims)])
        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)
        self.output_cell_dim = opt.output_cell_dim
        self.phase_embeddings = nn.ModuleList(
            [PositionEmbedding(embed_dim, input_dim=self.speaker_num,
                               device=self.device)
             for embed_dim in self.embed_dims])
        self.out_dropout_rate = opt.out_dropout_rate
        self.num_layers = opt.num_layers
        self.state_product = QProduct()

        # The recurrent state lives in the tensor-product space of the modalities
        self.recurrent_dim = 1
        for embed_dim in self.embed_dims:
            self.recurrent_dim = self.recurrent_dim * embed_dim

        self.recurrent_cells = nn.ModuleList(
            [QRNNCell(self.recurrent_dim, device=self.device)] * self.num_layers)
        # forward() applies self.activation between recurrent steps;
        # assumed to mirror the choice used in QMNAblation
        self.activation = QActivation(scale_factor=1, beta=0.8)
        self.measurement = QMeasurement(self.recurrent_dim)
        self.fc_out = SimpleNet(self.recurrent_dim, self.output_cell_dim,
                                self.out_dropout_rate, self.n_classes,
                                output_activation=nn.Tanh())

    def get_params(self):
        unitary_params = []
        remaining_params = []

        for i in range(self.num_layers):
            unitary_params.append(self.recurrent_cells[i].unitary_x)
            unitary_params.append(self.recurrent_cells[i].unitary_h)
            remaining_params.append(self.recurrent_cells[i].Lambda)

        remaining_params.extend(list(self.projections.parameters()))
        remaining_params.extend(list(self.phase_embeddings.parameters()))
        unitary_params.extend(list(self.measurement.parameters()))
        remaining_params.extend(list(self.fc_out.parameters()))

        return unitary_params, remaining_params

    def forward(self, in_modalities):
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        batch_size = in_modalities[0].shape[0]
        time_stamps = in_modalities[0].shape[1]

        # Project all modalities of each utterance to the same space
        utterance_reps = [nn.ReLU()(projection(x))
                          for x, projection in zip(in_modalities, self.projections)]

        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]
        phases = [phase_embed(smask.argmax(dim=-1))
                  for phase_embed in self.phase_embeddings]

        unimodal_pure = [self.multiply([phase, amplitude])
                         for phase, amplitude in zip(phases, amplitudes)]
        multimodal_pure = self.state_product(unimodal_pure)
        in_states = self.outer(multimodal_pure)

        # Take the amplitudes
        # multiply with modality specific vectors to construct weights
        # weights = [self.norm(rep) for rep in utterance_reps]
        # weights = F.softmax(torch.cat(weights, dim=-1), dim=-1)
        # unimodal_matrices = [self.outer(s) for s in unimodal_pure]
        # in_states = self.mixture([unimodal_matrices, weights])

        for l in range(self.num_layers):
            # Initialize the cell h
            h_r = torch.stack(
                batch_size * [torch.eye(self.recurrent_dim) / self.recurrent_dim],
                dim=0)
            h_i = torch.zeros_like(h_r)
            h = [h_r.to(self.device), h_i.to(self.device)]
            all_h = []
            for t in range(time_stamps):
                h = self.recurrent_cells[l](in_states[t], h)
                all_h.append(self.activation(h))
            in_states = all_h

        output = []
        for _h in in_states:
            measurement_probs = self.measurement(_h)
            _output = self.fc_out(measurement_probs)
            output.append(_output)

        output = torch.stack(output, dim=-2)
        log_prob = F.log_softmax(output, 2)  # batch, seq_len, n_classes
        return log_prob
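
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal driver for QMPN, assuming the supporting layers used above
# (ComplexMultiply, QOuter, QMixture, L2Norm, PositionEmbedding, QProduct,
# QRNNCell, QMeasurement, QActivation, SimpleNet) are importable from this
# repository. All opt values and tensor shapes below are placeholders chosen
# for illustration, not defaults taken from the code.
from types import SimpleNamespace
import torch

opt = SimpleNamespace(
    device=torch.device('cpu'),
    input_dims=[100, 73, 300],   # e.g. acoustic / visual / textual feature sizes
    embed_dims='8,8,8',          # parsed with split(',') in __init__
    speaker_num=9,
    dataset_name='iemocap',
    output_dim=6,
    output_cell_dim=50,
    out_dropout_rate=0.2,
    num_layers=1,
)
model = QMPN(opt)

batch_size, seq_len = 2, 4
modalities = [torch.randn(batch_size, seq_len, d) for d in opt.input_dims]
# One-hot speaker matrix; forward() reads it from position -2 of the input list.
speakers = torch.randint(opt.speaker_num, (batch_size, seq_len))
smask = torch.eye(opt.speaker_num)[speakers]
umask = torch.ones(batch_size, seq_len)          # last element is unused by forward()
log_prob = model(modalities + [smask, umask])    # (batch, seq_len, n_classes)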
class QAttN(nn.Module):
    def __init__(self, opt):
        super(QAttN, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)
        self.embed_dim = opt.embed_dim
        self.speaker_num = opt.speaker_num
        self.n_classes = opt.output_dim
        self.output_cell_dim = opt.output_cell_dim
        self.dataset_name = opt.dataset_name
        self.projections = nn.ModuleList(
            [nn.Linear(dim, self.embed_dim) for dim in self.input_dims])
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)
        self.phase_embeddings = nn.ModuleList(
            [PositionEmbedding(self.embed_dim, input_dim=1, device=self.device)]
            * len(self.input_dims))
        self.out_dropout_rate = opt.out_dropout_rate
        self.attention = QAttention()
        self.num_modalities = len(self.input_dims)
        # self.out_dropout = QDropout(p=self.out_dropout_rate)
        # self.dense = QDense(self.embed_dim, self.n_classes)
        self.measurement_type = opt.measurement_type
        if opt.measurement_type == 'quantum':
            self.measurement = QMeasurement(self.embed_dim)
            self.fc_out = SimpleNet(self.embed_dim * self.num_modalities,
                                    self.output_cell_dim,
                                    self.out_dropout_rate,
                                    self.n_classes,
                                    output_activation=nn.Tanh())
        else:
            self.measurement = ComplexMeasurement(self.embed_dim * self.num_modalities,
                                                  units=20)
            self.fc_out = SimpleNet(20, self.output_cell_dim,
                                    self.out_dropout_rate,
                                    self.n_classes,
                                    output_activation=nn.Tanh())

    def get_params(self):
        unitary_params = []
        remaining_params = []

        remaining_params.extend(list(self.projections.parameters()))
        remaining_params.extend(list(self.phase_embeddings.parameters()))

        if self.measurement_type == 'quantum':
            unitary_params.extend(list(self.measurement.parameters()))
        else:
            remaining_params.extend(list(self.measurement.parameters()))

        remaining_params.extend(list(self.fc_out.parameters()))
        return unitary_params, remaining_params

    def forward(self, in_modalities):
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        batch_size = in_modalities[0].shape[0]
        time_stamps = in_modalities[0].shape[1]

        # Project all modalities of each utterance to the same space
        # utterance_reps = [nn.Tanh()(projection(x)) for x, projection in zip(in_modalities, self.projections)]
        utterance_reps = [
            nn.ReLU()(projection(x))
            for x, projection in zip(in_modalities, self.projections)
        ]

        # Take the amplitudes
        # multiply with modality specific vectors to construct weights
        weights = [self.norm(rep) for rep in utterance_reps]
        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]
        phases = [
            phase_embed(smask.argmax(dim=-1))
            for phase_embed in self.phase_embeddings
        ]
        unimodal_pure = [
            self.multiply([phase, amplitude])
            for phase, amplitude in zip(phases, amplitudes)
        ]
        unimodal_matrices = [self.outer(s) for s in unimodal_pure]

        probs = []
        # For each modality
        # we mix the remaining modalities as queries (to_be_measured systems)
        # and treat the modality features as keys (measurement operators)
        for ind in range(self.num_modalities):
            # Obtain mixed states for the remaining modalities
            other_weights = [
                weights[i] for i in range(self.num_modalities) if not i == ind
            ]
            mixture_weights = F.softmax(torch.cat(other_weights, dim=-1), dim=-1)
            other_states = [
                unimodal_matrices[i] for i in range(self.num_modalities)
                if not i == ind
            ]
            q_states = self.mixture([other_states, mixture_weights])

            # Obtain pure states and weights for the modality of interest
            k_weights = weights[ind]
            k_states = unimodal_pure[ind]

            # Compute cross-modal interactions, output being a list of
            # post-measurement states
            in_states = self.attention(q_states, k_states, k_weights)

            # Apply measurement to each output state
            output = []
            for _h in in_states:
                measurement_probs = self.measurement(_h)
                output.append(measurement_probs)
            probs.append(output)

        # Concatenate the measurement probabilities per time stamp
        concat_probs = [
            self.fc_out(torch.cat(output_t, dim=-1)) for output_t in zip(*probs)
        ]
        concat_probs = torch.stack(concat_probs, dim=-2)
        log_prob = F.log_softmax(concat_probs, 2)  # batch, seq_len, n_classes
        return log_prob
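
# --- Usage sketch (illustrative only) ---------------------------------------
# QAttN can be driven exactly like the QMPN sketch above; only the
# hyperparameters differ. The values below are placeholders; measurement_type
# selects between the QMeasurement and ComplexMeasurement branches in __init__.
from types import SimpleNamespace
import torch

opt_attn = SimpleNamespace(
    device=torch.device('cpu'),
    input_dims=[100, 73, 300],
    embed_dim=8,                   # a single shared embedding size
    speaker_num=9,
    dataset_name='iemocap',
    output_dim=6,
    output_cell_dim=50,
    out_dropout_rate=0.2,
    measurement_type='quantum',    # any other value uses ComplexMeasurement
)
model = QAttN(opt_attn)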
def main(params, config):
    data_pedestal = PedestalDataset(config)
    train_loader, validation_loader = split_dataset(data_pedestal,
                                                    params['batch_size'])

    if config['experiment']['load_model'] != None:
        PATH = config['experiment']['load_model']
        checkpoint = torch.load(PATH)

        # Load Model
        net = SimpleNet(params, config)
        net.load_state_dict(checkpoint['model_state_dict'])

        # Load Optimizer
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(), 0.0)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Assign Loss Function
        loss_func = model_utils.map_loss_func(params['loss'])

        # Set EPOCH and LOSS for retraining
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        net = SimpleNet(params, config)
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(),
                                              params['learning_rate'])
        loss_func = model_utils.map_loss_func(params['loss'])

    epochs = config['epochs']
    last_results = []
    metrics = {}

    for epoch in range(epochs):
        # Training
        net.train()
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            inputs, targets = batch['input'], batch['target']
            output = net(inputs)
            loss = loss_func(output, targets)
            loss.backward()
            optimizer.step()

        # Validation every fifth epoch
        if epoch % 5 == 4:
            net.eval()
            max_error = 0.0
            for i, batch in enumerate(validation_loader):
                inputs, targets = batch['input'], batch['target']
                output = net(inputs)
                MSE = torch.sum((output - targets) ** 2) / (len(output) * params['batch_size'])
                max_error = max(MSE, max_error)
            score = -math.log10(max_error)
            # print(epoch, score)
            if epoch > epochs - 5:
                last_results.append(score)

    final_score = min(last_results)
    metrics['default'] = final_score

    if config['experiment']['save_model'] is not None:
        PATH = config['experiment']['save_model']
        # save model
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, PATH)
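
# --- Usage sketch (illustrative only) ---------------------------------------
# Only the params/config keys that main() actually reads are listed below;
# the concrete values are placeholders, and PedestalDataset/SimpleNet may
# require additional config entries not shown here.
params = {
    'batch_size': 32,
    'optimizer': 'adam',        # resolved by model_utils.map_optimizer
    'loss': 'mse',              # resolved by model_utils.map_loss_func
    'learning_rate': 1e-3,
}
config = {
    'epochs': 100,
    'experiment': {
        'load_model': None,             # or a checkpoint path to resume from
        'save_model': 'checkpoint.pt',  # or None to skip saving
    },
}
main(params, config)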
class QMNAblation(nn.Module):
    def __init__(self, opt):
        super(QMNAblation, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)
        self.embed_dim = opt.embed_dim
        self.speaker_num = opt.speaker_num
        self.dataset_name = opt.dataset_name

        # MELD data
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.n_classes = opt.output_dim
        self.input_concat = opt.input_concat
        self.zero_phase = opt.zero_phase
        self.measurement_type = opt.measurement_type
        self.classical_recurrent = opt.classical_recurrent
        self.quantum_recurrent = opt.quantum_recurrent

        if self.input_concat:
            self.projections = nn.ModuleList(
                [nn.Linear(self.total_input_dim, self.embed_dim)])
        elif self.classical_recurrent:
            self.projections = nn.ModuleList(
                [nn.GRU(dim, self.embed_dim, 1) for dim in self.input_dims])
        else:
            self.projections = nn.ModuleList(
                [nn.Linear(dim, self.embed_dim) for dim in self.input_dims])

        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)
        self.output_cell_dim = opt.output_cell_dim
        self.phase_embeddings = nn.ModuleList(
            [PositionEmbedding(self.embed_dim, input_dim=self.speaker_num,
                               zero_phase=self.zero_phase, device=self.device)]
            * len(self.input_dims))
        self.out_dropout_rate = opt.out_dropout_rate
        self.num_layers = opt.num_layers
        self.recurrent_cells = nn.ModuleList(
            [QRNNCell(self.embed_dim, device=self.device)] * self.num_layers)
        self.out_dropout = QDropout(p=self.out_dropout_rate)
        self.activation = QActivation(scale_factor=1, beta=0.8)
        self.measurement_type = opt.measurement_type
        if self.measurement_type == 'quantum':
            self.measurement = QMeasurement(self.embed_dim)
            self.fc_out = SimpleNet(self.embed_dim, self.output_cell_dim,
                                    self.out_dropout_rate, self.n_classes,
                                    output_activation=nn.Tanh())
        elif self.measurement_type == 'flatten':
            self.fc_out = SimpleNet(self.embed_dim, self.output_cell_dim,
                                    self.out_dropout_rate, self.n_classes,
                                    output_activation=nn.Tanh())
        else:
            self.measurement = ComplexMeasurement(self.embed_dim,
                                                  units=self.embed_dim)
            self.fc_out = SimpleNet(self.embed_dim, self.output_cell_dim,
                                    self.out_dropout_rate, self.n_classes,
                                    output_activation=nn.Tanh())

    def get_params(self):
        unitary_params = []
        remaining_params = []

        for i in range(self.num_layers):
            unitary_params.append(self.recurrent_cells[i].unitary_x)
            unitary_params.append(self.recurrent_cells[i].unitary_h)
            remaining_params.append(self.recurrent_cells[i].Lambda)

        remaining_params.extend(list(self.projections.parameters()))
        remaining_params.extend(list(self.phase_embeddings.parameters()))

        if self.measurement_type == 'quantum':
            unitary_params.extend(list(self.measurement.parameters()))
        elif self.measurement_type != 'flatten':
            # the 'flatten' ablation has no measurement module
            remaining_params.extend(list(self.measurement.parameters()))
        remaining_params.extend(list(self.fc_out.parameters()))

        return unitary_params, remaining_params

    def forward(self, in_modalities):
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        batch_size = in_modalities[0].shape[0]
        time_stamps = in_modalities[0].shape[1]

        # Project all modalities of each utterance to the same space
        # utterance_reps = [nn.Tanh()(projection(x)) for x, projection in zip(in_modalities, self.projections)]
        if self.input_concat:
            # keep a one-element list so it zips with the single projection below
            in_modalities = [torch.cat(in_modalities, dim=-1)]

        if self.classical_recurrent:
            utterance_reps = [
                nn.ReLU()(projection(x)[0])
                for x, projection in zip(in_modalities, self.projections)
            ]
        else:
            utterance_reps = [
                nn.ReLU()(projection(x))
                for x, projection in zip(in_modalities, self.projections)
            ]

        # Take the amplitudes
        # multiply with modality specific vectors to construct weights
        weights = [self.norm(rep) for rep in utterance_reps]
        weights = F.softmax(torch.cat(weights, dim=-1), dim=-1)

        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]
        phases = [
            phase_embed(smask.argmax(dim=-1))
            for phase_embed in self.phase_embeddings
        ]
        unimodal_pure = [
            self.multiply([phase, amplitude])
            for phase, amplitude in zip(phases, amplitudes)
        ]
        unimodal_matrices = [self.outer(s) for s in unimodal_pure]
        in_states = self.mixture([unimodal_matrices, weights])

        if self.quantum_recurrent:
            for l in range(self.num_layers):
                # Initialize the cell h
                h_r = torch.stack(
                    batch_size * [torch.eye(self.embed_dim) / self.embed_dim],
                    dim=0)
                h_i = torch.zeros_like(h_r)
                h = [h_r.to(self.device), h_i.to(self.device)]
                all_h = []
                for t in range(time_stamps):
                    h = self.recurrent_cells[l](in_states[t], h)
                    all_h.append(self.activation(h))
                in_states = all_h

        output = []
        for _h in in_states:
            # _h = self.out_dropout(_h)
            # _h = self.dense(_h)
            # measurement_probs = self.measurement(_h)
            if self.measurement_type == 'flatten':
                _output = self.fc_out(_h[0].reshape(batch_size, -1))
            else:
                measurement_probs = self.measurement(_h)
                _output = self.fc_out(measurement_probs)
            output.append(_output)

        output = torch.stack(output, dim=-2)
        log_prob = F.log_softmax(output, 2)  # batch, seq_len, n_classes
        return log_prob
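
# --- Ablation options (illustrative only) -----------------------------------
# QMNAblation toggles its ablations through the flags read in __init__ and
# forward(). The values below are placeholders, not recommended settings.
from types import SimpleNamespace
import torch

opt_abl = SimpleNamespace(
    device=torch.device('cpu'),
    input_dims=[100, 73, 300],
    embed_dim=8,
    speaker_num=9,
    dataset_name='iemocap',
    output_dim=6,
    output_cell_dim=50,
    out_dropout_rate=0.2,
    num_layers=1,
    input_concat=False,          # True: concatenate modalities before a single projection
    zero_phase=False,            # passed through to PositionEmbedding
    classical_recurrent=False,   # True: project with nn.GRU instead of nn.Linear
    quantum_recurrent=True,      # True: run the QRNNCell stack over the mixed states
    measurement_type='quantum',  # 'quantum' | 'flatten' | other -> ComplexMeasurement
)
model = QMNAblation(opt_abl)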
class QMultiTask(nn.Module):
    def __init__(self, opt):
        super(QMultiTask, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)
        self.embed_dim = opt.embed_dim
        self.speaker_num = opt.speaker_num
        self.dataset_name = opt.dataset_name
        self.features = opt.features

        # MELD data
        # The one-hot vectors are not the global user ID
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.n_classes_emo = opt.output_dim_emo
        self.n_classes_act = opt.output_dim_act
        self.projections = nn.ModuleList(
            [nn.Linear(dim, self.embed_dim) for dim in self.input_dims])
        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)
        self.output_cell_dim = opt.output_cell_dim
        self.phase_embeddings = nn.ModuleList(
            [PositionEmbedding(self.embed_dim, input_dim=self.speaker_num,
                               device=self.device)] * len(self.input_dims))
        self.out_dropout_rate = opt.out_dropout_rate
        self.measurement_emotion = QMeasurement(self.embed_dim)
        self.measurement_act = QMeasurement(self.embed_dim)
        self.fc_out_emo = SimpleNet(self.embed_dim, self.output_cell_dim,
                                    self.out_dropout_rate, self.n_classes_emo,
                                    output_activation=nn.Tanh())
        self.fc_out_act = SimpleNet(self.embed_dim, self.output_cell_dim,
                                    self.out_dropout_rate, self.n_classes_act,
                                    output_activation=nn.Tanh())
        self.num_layers = opt.num_layers
        # self.rnn = nn.ModuleList([QRNNCell(self.embed_dim, device=self.device)] * self.num_layers)
        self.RNNs = nn.ModuleList([
            QRNN(self.embed_dim, self.device, self.num_layers)
            for i in range(len(opt.features))
        ])
        self.rnn_outer = QOuter()
        self.action_qrnn = QRNN(self.embed_dim, self.device, self.num_layers)

    def get_params(self):
        unitary_params = []
        remaining_params = []

        for i in range(len(self.features)):
            qrnn = self.RNNs[i]
            for k in range(self.num_layers):
                unitary_params.append(qrnn.recurrent_cells[k].unitary_x)
                unitary_params.append(qrnn.recurrent_cells[k].unitary_h)
                remaining_params.append(qrnn.recurrent_cells[k].Lambda)

        for k in range(self.num_layers):
            unitary_params.append(self.action_qrnn.recurrent_cells[k].unitary_x)
            unitary_params.append(self.action_qrnn.recurrent_cells[k].unitary_h)
            remaining_params.append(self.action_qrnn.recurrent_cells[k].Lambda)

        unitary_params.extend(list(self.measurement_act.parameters()))
        unitary_params.extend(list(self.measurement_emotion.parameters()))
        remaining_params.extend(list(self.projections.parameters()))
        remaining_params.extend(list(self.phase_embeddings.parameters()))
        remaining_params.extend(list(self.fc_out_act.parameters()))
        remaining_params.extend(list(self.fc_out_emo.parameters()))

        return unitary_params, remaining_params

    def forward(self, in_modalities):
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        batch_size = in_modalities[0].shape[0]
        time_stamps = in_modalities[0].shape[1]

        # Project all modalities of each utterance to the same space
        utterance_reps = [
            nn.ReLU()(projection(x))
            for x, projection in zip(in_modalities, self.projections)
        ]

        # Take the amplitudes
        # multiply with modality specific vectors to construct weights
        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]
        phases = [
            phase_embed(smask.argmax(dim=-1))
            for phase_embed in self.phase_embeddings
        ]
        weights = [self.norm(rep) for rep in utterance_reps]
        weights = F.softmax(torch.cat(weights, dim=-1), dim=-1)

        unimodal_pure = [
            self.multiply([phase, amplitude])
            for phase, amplitude in zip(phases, amplitudes)
        ]
        unimodal_matrices = [self.outer(s) for s in unimodal_pure]
        rnn_unimodal_data = [
            rnn(data) for data, rnn in zip(unimodal_matrices, self.RNNs)
        ]

        # weights = [self.norm(rep) for rep in rnn_unimodal_data]
        # weights = F.softmax(torch.cat(weights, dim=-1), dim=-1)
        emo_states = self.mixture([rnn_unimodal_data, weights])
        action_states = self.action_qrnn(emo_states)

        ### emotion classifier ###
        output_emo = []
        for _h in emo_states:
            measurement_probs = self.measurement_emotion(_h)
            _output = self.fc_out_emo(measurement_probs)
            output_emo.append(_output)
        output_e = torch.stack(output_emo, dim=-2)
        log_prob_e = F.log_softmax(output_e, 2)  # batch, seq_len, n_classes

        ### action classifier ###
        output_act = []
        for _h in action_states:  # act head reads the states produced by action_qrnn
            measurement_probs = self.measurement_act(_h)
            _output = self.fc_out_act(measurement_probs)
            output_act.append(_output)
        output_a = torch.stack(output_act, dim=-2)
        log_prob_a = F.log_softmax(output_a, 2)  # batch, seq_len, n_classes

        return log_prob_e, log_prob_a
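
# --- Training-step sketch (illustrative only) -------------------------------
# QMultiTask.forward returns per-utterance log-probabilities for both heads,
# so each head can be trained with nn.NLLLoss. The names model, in_modalities,
# emo_labels and act_labels are hypothetical stand-ins for the data pipeline,
# and the equal 0.5/0.5 weighting is an assumption, not taken from this code.
import torch.nn as nn

criterion = nn.NLLLoss()
log_prob_e, log_prob_a = model(in_modalities)                  # each: (batch, seq_len, n_classes)
loss_emo = criterion(log_prob_e.permute(0, 2, 1), emo_labels)  # emo_labels: (batch, seq_len)
loss_act = criterion(log_prob_a.permute(0, 2, 1), act_labels)  # act_labels: (batch, seq_len)
loss = 0.5 * loss_emo + 0.5 * loss_act
loss.backward()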
def main(params, config):
    dataset = PedestalDataset(config)
    train_loader, validation_loader = split_dataset(dataset, params['batch_size'])

    if config['experiment']['load_model'] != None:
        PATH = config['experiment']['load_model']
        checkpoint = torch.load(PATH)

        # Load Model
        net = SimpleNet(params, config)
        net.load_state_dict(checkpoint['model_state_dict'])

        # Load Optimizer
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(), 0.0)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Assign Loss Function
        loss_func = model_utils.map_loss_func(params['loss'])

        # Set EPOCH and LOSS for retraining
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        net = SimpleNet(params, config)
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(),
                                              params['learning_rate'])
        loss_func = model_utils.map_loss_func(params['loss'])

    metrics = {}
    target_norms = dataset.targets_norms
    input_norms = dataset.inputs_norms
    target_list = dataset.target_params
    input_list = dataset.input_params

    net.eval()
    outputs = []
    actual_array = []
    scaled_list = []
    save_path = config['experiment']['name']

    for i, batch in enumerate(validation_loader):
        inputs, targets = batch['input'], batch['target']
        for val in inputs:
            output = net(val).detach().numpy()
            output = denormalize(output, target_list, target_norms)
            outputs.append(output[0])
            normed_vals = denormalize(val.numpy(), input_list, input_norms)
        for tar in targets:
            denorm_targ = denormalize(tar.numpy(), target_list, target_norms)
            actual_array.append(denorm_targ[0])

    if config['experiment']['target'] == 'density':
        for i, batch in enumerate(validation_loader):
            inputs, targets = batch['input'], batch['target']
            for val in inputs:
                normed_vals = denormalize(val.numpy(), input_list, input_norms)
                scaled_vals = fitted_scale(normed_vals)
                scaled_list.append(scaled_vals)
        plt.scatter(actual_array, scaled_list, label='Scale Law')

    plt.scatter(actual_array, actual_array, label='Actual')
    plt.scatter(actual_array, outputs, label='NN')
    plt.legend()
    plt.ylabel('Predicted')
    plt.xlabel('Actual Density Height')
    plt.ylim(0, 12)
    plt.title('Neural Network vs Scaling Law')
    plt.savefig('./results/' + save_path)
    plt.show()