Exemplo n.º 1
0
class QMPN(nn.Module):
    """Multimodal network working on the product of unimodal states.

    Every modality is projected by its own linear layer, combined with a
    speaker-dependent phase via ``ComplexMultiply``, merged across modalities
    with ``QProduct`` followed by ``QOuter``, run through ``num_layers``
    passes of ``QRNNCell`` and finally measured and classified per time step.

    Args:
        opt: configuration object. The attributes read here are ``device``,
            ``input_dims``, ``embed_dims``, ``speaker_num``, ``dataset_name``,
            ``output_dim``, ``output_cell_dim``, ``out_dropout_rate`` and
            ``num_layers``.
    """

    def __init__(self, opt):
        super(QMPN, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)

        # ``opt.embed_dims`` may be a single int, a comma-separated string
        # ("50,50,50") or an iterable of ints. The original code crashed on
        # the int case because it unconditionally called ``.split``.
        embed_dims = opt.embed_dims
        if isinstance(embed_dims, int):
            self.embed_dims = [embed_dims]
        elif isinstance(embed_dims, str):
            self.embed_dims = [int(s) for s in embed_dims.split(',')]
        else:
            self.embed_dims = [int(s) for s in embed_dims]

        self.speaker_num = opt.speaker_num
        self.dataset_name = opt.dataset_name

        # MELD data
        # The one-hot vectors are not the global user ID
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.n_classes = opt.output_dim

        # One projection per (input_dim, embed_dim) pair; ``zip`` stops at the
        # shorter of the two lists.
        self.projections = nn.ModuleList([
            nn.Linear(dim, embed_dim)
            for dim, embed_dim in zip(self.input_dims, self.embed_dims)
        ])

        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)
        self.output_cell_dim = opt.output_cell_dim
        self.phase_embeddings = nn.ModuleList([
            PositionEmbedding(embed_dim,
                              input_dim=self.speaker_num,
                              device=self.device)
            for embed_dim in self.embed_dims
        ])
        self.out_dropout_rate = opt.out_dropout_rate
        self.num_layers = opt.num_layers
        self.state_product = QProduct()

        # The recurrent state lives in the product space, whose size is the
        # product of all embedding sizes.
        self.recurrent_dim = 1
        for embed_dim in self.embed_dims:
            self.recurrent_dim = self.recurrent_dim * embed_dim

        # NOTE(review): list multiplication repeats the SAME QRNNCell object,
        # so every layer shares one set of weights -- confirm this is intended.
        self.recurrent_cells = nn.ModuleList(
            [QRNNCell(self.recurrent_dim, device=self.device)] *
            self.num_layers)

        # ``forward`` applies ``self.activation`` to every hidden state but the
        # attribute was never created; use the same activation as QMNAblation.
        self.activation = QActivation(scale_factor=1, beta=0.8)

        self.measurement = QMeasurement(self.recurrent_dim)
        self.fc_out = SimpleNet(self.recurrent_dim, self.output_cell_dim,
                                self.out_dropout_rate, self.n_classes,
                                output_activation=nn.Tanh())

    def get_params(self):
        """Split the trainable parameters into two groups.

        Returns:
            tuple: ``(unitary_params, remaining_params)``. The first list
            holds ``unitary_x`` / ``unitary_h`` of the recurrent cells plus the
            measurement parameters; the second holds everything else. Each
            parameter object appears at most once across the two lists, so an
            optimizer never receives duplicates.
        """
        unitary_params = []
        remaining_params = []
        seen = set()

        def _add(bucket, params):
            # Identity-based de-duplication: shared cells would otherwise
            # contribute the same parameter once per layer.
            for p in params:
                if id(p) not in seen:
                    seen.add(id(p))
                    bucket.append(p)

        for i in range(self.num_layers):
            cell = self.recurrent_cells[i]
            _add(unitary_params, [cell.unitary_x, cell.unitary_h])
            _add(remaining_params, [cell.Lambda])

        _add(remaining_params, self.projections.parameters())
        _add(remaining_params, self.phase_embeddings.parameters())

        _add(unitary_params, self.measurement.parameters())
        _add(remaining_params, self.fc_out.parameters())

        return unitary_params, remaining_params

    def forward(self, in_modalities):
        """Compute per-time-step log-probabilities.

        Args:
            in_modalities: sequence whose last two entries are dropped from
                the modality list; the second-to-last one is the speaker mask
                (its argmax over the last axis is fed to the phase
                embeddings). Axis 0 of a modality tensor is used as the batch
                size and axis 1 as the number of time steps.

        Returns:
            Tensor of log-softmax scores taken over dimension 2 of the stacked
            per-step outputs.
        """
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        batch_size = in_modalities[0].shape[0]
        time_stamps = in_modalities[0].shape[1]

        # Project all modalities of each utterance to the same space
        utterance_reps = [
            nn.ReLU()(projection(x))
            for x, projection in zip(in_modalities, self.projections)
        ]

        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]
        speaker_ids = smask.argmax(dim=-1)
        phases = [
            phase_embed(speaker_ids) for phase_embed in self.phase_embeddings
        ]

        unimodal_pure = [
            self.multiply([phase, amplitude])
            for phase, amplitude in zip(phases, amplitudes)
        ]
        multimodal_pure = self.state_product(unimodal_pure)
        in_states = self.outer(multimodal_pure)

        for l in range(self.num_layers):
            # Initialize the cell h with identity / dim for the real part and
            # zeros for the imaginary part.
            h_r = torch.stack(
                batch_size * [torch.eye(self.recurrent_dim) / self.recurrent_dim],
                dim=0)
            h_i = torch.zeros_like(h_r)
            h = [h_r.to(self.device), h_i.to(self.device)]
            all_h = []
            for t in range(time_stamps):
                h = self.recurrent_cells[l](in_states[t], h)
                all_h.append(self.activation(h))
            # The outputs of this layer are the inputs of the next one.
            in_states = all_h

        output = []
        for _h in in_states:
            measurement_probs = self.measurement(_h)
            output.append(self.fc_out(measurement_probs))

        output = torch.stack(output, dim=-2)
        log_prob = F.log_softmax(output, 2)  # batch, seq_len,  n_classes

        return log_prob
Exemplo n.º 2
0
class QAttN(nn.Module):
    """Multimodal network with cross-modal attention between modality states.

    For every modality the remaining modalities are mixed into query states
    and the modality itself provides the keys; the attention outputs are
    measured, concatenated across modalities per time step and classified.

    Args:
        opt: configuration object. The attributes read here are ``device``,
            ``input_dims``, ``embed_dim``, ``speaker_num``, ``output_dim``,
            ``output_cell_dim``, ``dataset_name``, ``out_dropout_rate`` and
            ``measurement_type``.
    """

    def __init__(self, opt):
        super(QAttN, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)
        self.embed_dim = opt.embed_dim
        self.speaker_num = opt.speaker_num
        self.n_classes = opt.output_dim
        self.output_cell_dim = opt.output_cell_dim
        self.dataset_name = opt.dataset_name

        self.projections = nn.ModuleList(
            [nn.Linear(dim, self.embed_dim) for dim in self.input_dims])
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)

        # Size the phase embedding by the number of speakers, as the sibling
        # models do. The previous hard-coded ``input_dim=1`` ignored
        # ``self.speaker_num`` although ``forward`` indexes the embedding with
        # the speaker id. For MELD (speaker_num == 1) nothing changes.
        # NOTE(review): list multiplication repeats the SAME embedding object,
        # so all modalities share one phase embedding -- confirm intended.
        self.phase_embeddings = nn.ModuleList([
            PositionEmbedding(self.embed_dim,
                              input_dim=self.speaker_num,
                              device=self.device)
        ] * len(self.input_dims))
        self.out_dropout_rate = opt.out_dropout_rate
        self.attention = QAttention()
        self.num_modalities = len(self.input_dims)

        self.measurement_type = opt.measurement_type
        if opt.measurement_type == 'quantum':
            self.measurement = QMeasurement(self.embed_dim)
            # One block of ``embed_dim`` measurement values per modality.
            self.fc_out = SimpleNet(self.embed_dim * self.num_modalities,
                                    self.output_cell_dim,
                                    self.out_dropout_rate,
                                    self.n_classes,
                                    output_activation=nn.Tanh())

        else:
            self.measurement = ComplexMeasurement(self.embed_dim *
                                                  self.num_modalities,
                                                  units=20)
            self.fc_out = SimpleNet(20,
                                    self.output_cell_dim,
                                    self.out_dropout_rate,
                                    self.n_classes,
                                    output_activation=nn.Tanh())

    def get_params(self):
        """Split the trainable parameters into two groups.

        Returns:
            tuple: ``(unitary_params, remaining_params)``. The measurement
            parameters go into the first list only when ``measurement_type``
            is ``'quantum'``; everything else goes into the second list.
        """
        unitary_params = []
        remaining_params = []

        remaining_params.extend(list(self.projections.parameters()))
        remaining_params.extend(list(self.phase_embeddings.parameters()))

        if self.measurement_type == 'quantum':
            unitary_params.extend(list(self.measurement.parameters()))
        else:
            remaining_params.extend(list(self.measurement.parameters()))
        remaining_params.extend(list(self.fc_out.parameters()))

        return unitary_params, remaining_params

    def forward(self, in_modalities):
        """Compute per-time-step log-probabilities.

        Args:
            in_modalities: sequence whose last two entries are dropped from
                the modality list; the second-to-last one is the speaker mask
                (its argmax over the last axis is fed to the phase
                embeddings).

        Returns:
            Tensor of log-softmax scores taken over dimension 2 of the stacked
            per-step outputs.
        """
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        # Project all modalities of each utterance to the same space
        utterance_reps = [
            nn.ReLU()(projection(x))
            for x, projection in zip(in_modalities, self.projections)
        ]

        # Take the amplitudes
        # multiply with modality specific vectors to construct weights
        weights = [self.norm(rep) for rep in utterance_reps]

        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]

        speaker_ids = smask.argmax(dim=-1)
        phases = [
            phase_embed(speaker_ids) for phase_embed in self.phase_embeddings
        ]
        unimodal_pure = [
            self.multiply([phase, amplitude])
            for phase, amplitude in zip(phases, amplitudes)
        ]
        unimodal_matrices = [self.outer(s) for s in unimodal_pure]

        probs = []

        # For each modality
        # we mix the remaining modalities as queries (to_be_measured systems)
        # and treat the modality features as keys (measurement operators)
        for ind in range(self.num_modalities):
            others = [i for i in range(self.num_modalities) if i != ind]

            # Obtain mixed states for the rest modalities
            mixture_weights = F.softmax(
                torch.cat([weights[i] for i in others], dim=-1), dim=-1)
            other_states = [unimodal_matrices[i] for i in others]
            q_states = self.mixture([other_states, mixture_weights])

            # Obtain pure states and weights for the modality of interest
            k_weights = weights[ind]
            k_states = unimodal_pure[ind]

            # Compute cross-modal interactions, output being a list of
            # post-measurement states
            in_states = self.attention(q_states, k_states, k_weights)

            # Apply measurement to each output state
            probs.append([self.measurement(_h) for _h in in_states])

        # Concatenate the measurement probabilities per-time-stamp
        concat_probs = [
            self.fc_out(torch.cat(output_t, dim=-1))
            for output_t in zip(*probs)
        ]
        concat_probs = torch.stack(concat_probs, dim=-2)

        log_prob = F.log_softmax(concat_probs, 2)  # batch, seq_len,  n_classes

        return log_prob
Exemplo n.º 3
0
def main(params, config):
    """Train ``SimpleNet`` on the pedestal dataset and optionally save it.

    Args:
        params: hyper-parameter mapping; the keys read here are
            ``'batch_size'``, ``'optimizer'``, ``'loss'`` and (for a fresh
            run) ``'learning_rate'``.
        config: experiment configuration; the keys read here are
            ``'epochs'`` and ``config['experiment']['load_model']`` /
            ``['save_model']`` (checkpoint paths or ``None``).

    The validation score is ``-log10`` of the largest per-batch error and is
    refreshed every fifth epoch. The minimum of the scores recorded during the
    last four epochs is stored in ``metrics['default']`` (``None`` when no
    validation pass has run, e.g. for fewer than five epochs).

    NOTE(review): when resuming, the checkpoint's epoch is only kept if the
    loop below does not run; the loop always restarts counting from 0.
    """
    data_pedestal = PedestalDataset(config)

    train_loader, validation_loader = split_dataset(data_pedestal,
                                                    params['batch_size'])

    load_path = config['experiment']['load_model']
    if load_path is not None:
        checkpoint = torch.load(load_path)
        # Load Model
        net = SimpleNet(params, config)
        net.load_state_dict(checkpoint['model_state_dict'])

        # Load Optimizer; the placeholder learning rate is replaced by the
        # value stored in the checkpoint's optimizer state.
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(), 0.0)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Set EPOCH and LOSS for retraining
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
    else:
        net = SimpleNet(params, config)
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(),
                                              params['learning_rate'])
        # Defaults so that saving still works if no training step runs.
        epoch = 0
        loss = None

    loss_func = model_utils.map_loss_func(params['loss'])

    epochs = config['epochs']

    last_results = []
    metrics = {}
    score = None  # stays None until the first validation pass completes
    for epoch in range(epochs):

        # TRAINING
        net.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch['input'], batch['target']
            output = net(inputs)
            loss = loss_func(output, targets)
            loss.backward()
            optimizer.step()

        # Validation
        if epoch % 5 == 4:
            net.eval()
            max_error = 0.0
            validated = False

            # No gradients are needed for scoring.
            with torch.no_grad():
                for batch in validation_loader:
                    inputs, targets = batch['input'], batch['target']
                    output = net(inputs)
                    mse = torch.sum((output - targets)**
                                    2) / (len(output) * params['batch_size'])
                    max_error = max(mse.item(), max_error)
                    validated = True

            if validated:
                # log10(0) is undefined; a perfect fit maps to +inf.
                if max_error > 0:
                    score = -math.log10(max_error)
                else:
                    score = math.inf

        if epoch > epochs - 5 and score is not None:
            last_results.append(score)

    final_score = min(last_results) if last_results else None
    metrics['default'] = final_score

    save_path = config['experiment']['save_model']
    if save_path is not None:
        # save model
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, save_path)
Exemplo n.º 4
0
class QMNAblation(nn.Module):
    """Configurable ablation variant of the multimodal network.

    Flags read from ``opt`` switch individual components:

    * ``input_concat`` -- concatenate all modalities and use one projection.
    * ``classical_recurrent`` -- use a GRU per modality as the projection.
    * ``zero_phase`` -- forwarded to ``PositionEmbedding``.
    * ``quantum_recurrent`` -- run the mixed states through the QRNN cells.
    * ``measurement_type`` -- ``'quantum'``, ``'flatten'`` or anything else
      (which selects ``ComplexMeasurement``).

    Other attributes read from ``opt``: ``device``, ``input_dims``,
    ``embed_dim``, ``speaker_num``, ``dataset_name``, ``output_dim``,
    ``output_cell_dim``, ``out_dropout_rate`` and ``num_layers``.
    """

    def __init__(self, opt):
        super(QMNAblation, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)
        self.embed_dim = opt.embed_dim
        self.speaker_num = opt.speaker_num
        self.dataset_name = opt.dataset_name

        # MELD data
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.n_classes = opt.output_dim
        self.input_concat = opt.input_concat
        self.zero_phase = opt.zero_phase
        self.measurement_type = opt.measurement_type
        self.classical_recurrent = opt.classical_recurrent
        self.quantum_recurrent = opt.quantum_recurrent

        if self.input_concat:
            self.projections = nn.ModuleList(
                [nn.Linear(self.total_input_dim, self.embed_dim)])

        elif self.classical_recurrent:
            # ``forward`` treats axis 0 of the inputs as the batch axis, so
            # the GRU must be told the data is batch-first; otherwise the
            # recurrence would run across the samples of a batch.
            self.projections = nn.ModuleList([
                nn.GRU(dim, self.embed_dim, 1, batch_first=True)
                for dim in self.input_dims
            ])
        else:
            self.projections = nn.ModuleList(
                [nn.Linear(dim, self.embed_dim) for dim in self.input_dims])

        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)
        self.output_cell_dim = opt.output_cell_dim
        # NOTE(review): list multiplication repeats the SAME module object,
        # so the phase embedding (and the QRNN cell below) is shared across
        # modalities (layers) -- confirm this is intended.
        self.phase_embeddings = nn.ModuleList([
            PositionEmbedding(self.embed_dim,
                              input_dim=self.speaker_num,
                              zero_phase=self.zero_phase,
                              device=self.device)
        ] * len(self.input_dims))

        self.out_dropout_rate = opt.out_dropout_rate
        self.num_layers = opt.num_layers
        self.recurrent_cells = nn.ModuleList(
            [QRNNCell(self.embed_dim, device=self.device)] * self.num_layers)
        self.out_dropout = QDropout(p=self.out_dropout_rate)
        self.activation = QActivation(scale_factor=1, beta=0.8)

        # 'flatten' feeds the state straight into the classifier and
        # therefore has no measurement layer.
        if self.measurement_type == 'quantum':
            self.measurement = QMeasurement(self.embed_dim)
        elif self.measurement_type != 'flatten':
            self.measurement = ComplexMeasurement(self.embed_dim,
                                                  units=self.embed_dim)

        # NOTE(review): in 'flatten' mode ``forward`` reshapes ``_h[0]`` to
        # (batch, -1); if that state is an embed_dim x embed_dim matrix the
        # classifier input would need embed_dim ** 2 features -- verify.
        self.fc_out = SimpleNet(self.embed_dim,
                                self.output_cell_dim,
                                self.out_dropout_rate,
                                self.n_classes,
                                output_activation=nn.Tanh())

    def get_params(self):
        """Split the trainable parameters into two groups.

        Returns:
            tuple: ``(unitary_params, remaining_params)``. ``unitary_x`` /
            ``unitary_h`` of the recurrent cells always go into the first
            list; the measurement parameters join them only for
            ``measurement_type == 'quantum'``. Each parameter object appears
            at most once across the two lists.
        """
        unitary_params = []
        remaining_params = []
        seen = set()

        def _add(bucket, params):
            # Identity-based de-duplication: shared cells would otherwise
            # contribute the same parameter once per layer.
            for p in params:
                if id(p) not in seen:
                    seen.add(id(p))
                    bucket.append(p)

        for i in range(self.num_layers):
            cell = self.recurrent_cells[i]
            _add(unitary_params, [cell.unitary_x, cell.unitary_h])
            _add(remaining_params, [cell.Lambda])

        _add(remaining_params, self.projections.parameters())
        _add(remaining_params, self.phase_embeddings.parameters())

        # 'flatten' mode defines no measurement layer at all.
        measurement = getattr(self, 'measurement', None)
        if measurement is not None:
            if self.measurement_type == 'quantum':
                _add(unitary_params, measurement.parameters())
            else:
                _add(remaining_params, measurement.parameters())
        _add(remaining_params, self.fc_out.parameters())

        return unitary_params, remaining_params

    def forward(self, in_modalities):
        """Compute per-time-step log-probabilities.

        Args:
            in_modalities: sequence whose last two entries are dropped from
                the modality list; the second-to-last one is the speaker mask
                (its argmax over the last axis is fed to the phase
                embeddings). Axis 0 of a modality tensor is used as the batch
                size and axis 1 as the number of time steps.

        Returns:
            Tensor of log-softmax scores taken over dimension 2 of the stacked
            per-step outputs.
        """
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        batch_size = in_modalities[0].shape[0]
        time_stamps = in_modalities[0].shape[1]

        # Project all modalities of each utterance to the same space
        if self.input_concat:
            # Keep a one-element list: zipping the bare concatenated tensor
            # with the projections would iterate over its first axis.
            in_modalities = [torch.cat(in_modalities, dim=-1)]

        utterance_reps = []
        for x, projection in zip(in_modalities, self.projections):
            projected = projection(x)
            if isinstance(projection, nn.GRU):
                # nn.GRU returns (output, h_n); keep the per-step outputs.
                projected = projected[0]
            utterance_reps.append(nn.ReLU()(projected))

        # Take the amplitudes
        # multiply with modality specific vectors to construct weights
        weights = [self.norm(rep) for rep in utterance_reps]
        weights = F.softmax(torch.cat(weights, dim=-1), dim=-1)

        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]
        speaker_ids = smask.argmax(dim=-1)
        phases = [
            phase_embed(speaker_ids) for phase_embed in self.phase_embeddings
        ]
        unimodal_pure = [
            self.multiply([phase, amplitude])
            for phase, amplitude in zip(phases, amplitudes)
        ]
        unimodal_matrices = [self.outer(s) for s in unimodal_pure]
        in_states = self.mixture([unimodal_matrices, weights])

        if self.quantum_recurrent:
            for l in range(self.num_layers):
                # Initialize the cell h with identity / dim for the real part
                # and zeros for the imaginary part.
                h_r = torch.stack(batch_size *
                                  [torch.eye(self.embed_dim) / self.embed_dim],
                                  dim=0)
                h_i = torch.zeros_like(h_r)
                h = [h_r.to(self.device), h_i.to(self.device)]
                all_h = []
                for t in range(time_stamps):
                    h = self.recurrent_cells[l](in_states[t], h)
                    all_h.append(self.activation(h))
                # The outputs of this layer are the inputs of the next one.
                in_states = all_h

        output = []

        for _h in in_states:
            if self.measurement_type == 'flatten':
                _output = self.fc_out(_h[0].reshape(batch_size, -1))
            else:
                measurement_probs = self.measurement(_h)
                _output = self.fc_out(measurement_probs)
            output.append(_output)

        output = torch.stack(output, dim=-2)
        log_prob = F.log_softmax(output, 2)  # batch, seq_len,  n_classes

        return log_prob
Exemplo n.º 5
0
class QMultiTask(nn.Module):
    """Two-headed network predicting emotion and act labels per time step.

    Every modality is turned into a state, passed through its own ``QRNN``
    and mixed into the emotion states. A further ``QRNN`` maps the emotion
    states to the action states. Each task has its own measurement layer and
    classifier.

    Args:
        opt: configuration object. The attributes read here are ``device``,
            ``input_dims``, ``embed_dim``, ``speaker_num``, ``dataset_name``,
            ``features``, ``output_dim_emo``, ``output_dim_act``,
            ``output_cell_dim``, ``out_dropout_rate`` and ``num_layers``.
    """

    def __init__(self, opt):
        super(QMultiTask, self).__init__()
        self.device = opt.device
        self.input_dims = opt.input_dims
        self.total_input_dim = sum(self.input_dims)
        self.embed_dim = opt.embed_dim
        self.speaker_num = opt.speaker_num
        self.dataset_name = opt.dataset_name
        self.features = opt.features

        # MELD data
        # The one-hot vectors are not the global user ID
        if self.dataset_name.lower() == 'meld':
            self.speaker_num = 1
        self.n_classes_emo = opt.output_dim_emo
        self.n_classes_act = opt.output_dim_act

        self.projections = nn.ModuleList(
            [nn.Linear(dim, self.embed_dim) for dim in self.input_dims])

        self.multiply = ComplexMultiply()
        self.outer = QOuter()
        self.norm = L2Norm(dim=-1)
        self.mixture = QMixture(device=self.device)
        self.output_cell_dim = opt.output_cell_dim
        # NOTE(review): list multiplication repeats the SAME embedding object,
        # so all modalities share one phase embedding -- confirm intended.
        self.phase_embeddings = nn.ModuleList([
            PositionEmbedding(
                self.embed_dim, input_dim=self.speaker_num, device=self.device)
        ] * len(self.input_dims))
        self.out_dropout_rate = opt.out_dropout_rate

        self.measurement_emotion = QMeasurement(self.embed_dim)
        self.measurement_act = QMeasurement(self.embed_dim)

        self.fc_out_emo = SimpleNet(self.embed_dim,
                                    self.output_cell_dim,
                                    self.out_dropout_rate,
                                    self.n_classes_emo,
                                    output_activation=nn.Tanh())

        self.fc_out_act = SimpleNet(self.embed_dim,
                                    self.output_cell_dim,
                                    self.out_dropout_rate,
                                    self.n_classes_act,
                                    output_activation=nn.Tanh())

        self.num_layers = opt.num_layers
        # One QRNN per entry of ``opt.features``.
        self.RNNs = nn.ModuleList([
            QRNN(self.embed_dim, self.device, self.num_layers)
            for i in range(len(opt.features))
        ])
        self.rnn_outer = QOuter()

        self.action_qrnn = QRNN(self.embed_dim, self.device, self.num_layers)

    def get_params(self):
        """Split the trainable parameters into two groups.

        Returns:
            tuple: ``(unitary_params, remaining_params)``. The first list
            holds ``unitary_x`` / ``unitary_h`` of every QRNN cell plus the
            parameters of both measurement layers; the second holds the
            cells' ``Lambda``, the projections, the phase embeddings and both
            classifiers.
        """
        unitary_params = []
        remaining_params = []

        # Per-modality QRNNs first, then the action QRNN.
        qrnns = [self.RNNs[i] for i in range(len(self.features))]
        qrnns.append(self.action_qrnn)
        for qrnn in qrnns:
            for k in range(self.num_layers):
                cell = qrnn.recurrent_cells[k]
                unitary_params.append(cell.unitary_x)
                unitary_params.append(cell.unitary_h)
                remaining_params.append(cell.Lambda)

        unitary_params.extend(list(self.measurement_act.parameters()))
        unitary_params.extend(list(self.measurement_emotion.parameters()))

        remaining_params.extend(list(self.projections.parameters()))
        remaining_params.extend(list(self.phase_embeddings.parameters()))

        remaining_params.extend(list(self.fc_out_act.parameters()))
        remaining_params.extend(list(self.fc_out_emo.parameters()))

        return unitary_params, remaining_params

    def forward(self, in_modalities):
        """Compute per-time-step log-probabilities for both tasks.

        Args:
            in_modalities: sequence whose last two entries are dropped from
                the modality list; the second-to-last one is the speaker mask
                (its argmax over the last axis is fed to the phase
                embeddings).

        Returns:
            tuple: ``(log_prob_e, log_prob_a)``, the log-softmax scores (over
            dimension 2) of the emotion and the action classifier.
        """
        smask = in_modalities[-2]  # Speaker ids
        in_modalities = in_modalities[:-2]

        # Project all modalities of each utterance to the same space
        utterance_reps = [
            nn.ReLU()(projection(x))
            for x, projection in zip(in_modalities, self.projections)
        ]

        # Take the amplitudes
        # multiply with modality specific vectors to construct weights
        amplitudes = [F.normalize(rep, dim=-1) for rep in utterance_reps]
        speaker_ids = smask.argmax(dim=-1)
        phases = [
            phase_embed(speaker_ids) for phase_embed in self.phase_embeddings
        ]

        weights = [self.norm(rep) for rep in utterance_reps]
        weights = F.softmax(torch.cat(weights, dim=-1), dim=-1)

        unimodal_pure = [
            self.multiply([phase, amplitude])
            for phase, amplitude in zip(phases, amplitudes)
        ]
        unimodal_matrices = [self.outer(s) for s in unimodal_pure]

        rnn_unimodal_data = [
            rnn(data) for data, rnn in zip(unimodal_matrices, self.RNNs)
        ]
        emo_states = self.mixture([rnn_unimodal_data, weights])

        action_states = self.action_qrnn(emo_states)

        ###emotion classifier###
        output_emo = [
            self.fc_out_emo(self.measurement_emotion(_h)) for _h in emo_states
        ]
        output_e = torch.stack(output_emo, dim=-2)
        log_prob_e = F.log_softmax(output_e, 2)  # batch, seq_len, n_classes

        ###action classifier###
        # Classify the states produced by the action QRNN. The previous code
        # iterated over ``emo_states`` here, leaving ``action_states`` (and
        # therefore ``action_qrnn``) without any effect on the output.
        output_act = [
            self.fc_out_act(self.measurement_act(_h)) for _h in action_states
        ]
        output_a = torch.stack(output_act, dim=-2)
        log_prob_a = F.log_softmax(output_a, 2)  # batch, seq_len,  n_classes

        return log_prob_e, log_prob_a
Exemplo n.º 6
0
def main(params, config):
    """Evaluate ``SimpleNet`` on the validation split and plot the results.

    Predictions and targets are de-normalised and drawn as scatter plots
    against the actual values. When ``config['experiment']['target']`` is
    ``'density'`` the output of ``fitted_scale`` on the de-normalised inputs
    is plotted as well. The figure is written to
    ``'./results/' + config['experiment']['name']`` and then shown.

    Args:
        params: hyper-parameter mapping; the keys read here are
            ``'batch_size'``, ``'optimizer'``, ``'loss'`` and (without a
            checkpoint) ``'learning_rate'``.
        config: experiment configuration; ``config['experiment']`` provides
            ``'load_model'``, ``'name'`` and ``'target'``.
    """
    import os  # local import: only needed to create the output directory

    dataset = PedestalDataset(config)

    train_loader, validation_loader = split_dataset(dataset,
                                                    params['batch_size'])

    load_path = config['experiment']['load_model']
    if load_path is not None:
        checkpoint = torch.load(load_path)
        # Load Model
        net = SimpleNet(params, config)
        net.load_state_dict(checkpoint['model_state_dict'])

        # The optimizer state is restored as in the training script; it is
        # not used during evaluation.
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(), 0.0)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Assign Loss Function
        loss_func = model_utils.map_loss_func(params['loss'])
    else:
        net = SimpleNet(params, config)
        optimizer = model_utils.map_optimizer(params['optimizer'],
                                              net.parameters(),
                                              params['learning_rate'])
        loss_func = model_utils.map_loss_func(params['loss'])

    target_norms = dataset.targets_norms
    input_norms = dataset.inputs_norms

    target_list = dataset.target_params
    input_list = dataset.input_params

    net.eval()

    outputs = []
    actual_array = []
    scaled_list = []

    save_path = config['experiment']['name']
    plot_scale_law = config['experiment']['target'] == 'density'

    # Single pass over the loader: predictions, targets and scaling-law
    # values are collected together so the three lists stay aligned even if
    # the loader yields samples in a different order on every iteration.
    for batch in validation_loader:
        inputs, targets = batch['input'], batch['target']
        for val in inputs:
            output = net(val).detach().numpy()
            output = denormalize(output, target_list, target_norms)
            outputs.append(output[0])

            if plot_scale_law:
                normed_vals = denormalize(val.numpy(), input_list, input_norms)
                scaled_list.append(fitted_scale(normed_vals))

        for tar in targets:
            denorm_targ = denormalize(tar.numpy(), target_list, target_norms)
            actual_array.append(denorm_targ[0])

    if plot_scale_law:
        plt.scatter(actual_array, scaled_list, label='Scale Law')

    plt.scatter(actual_array, actual_array, label='Actual')
    plt.scatter(actual_array, outputs, label='NN')
    plt.legend()
    plt.ylabel('Predicted')
    plt.xlabel('Actual Density Height')
    plt.ylim(0, 12)

    plt.title('Neural Network vs Scaling Law')

    out_file = './results/' + save_path
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    plt.savefig(out_file)
    plt.show()