Пример #1
0
def test(net, dataloader, tag=''):
    correct = 0
    total = 0
    if tag == 'Train':
        dataTestLoader = dataloader.trainloader
    else:
        dataTestLoader = dataloader.testloader
    with torch.no_grad():
        for data in dataTestLoader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    net.log('%s Accuracy of the network: %d %%' % (tag,
        100 * correct / total))

    class_correct = list(0. for i in range(10))
    class_total = list(0. for i in range(10))
    with torch.no_grad():
        for data in dataTestLoader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1


    for i in range(10):
        net.log('%s Accuracy of %5s : %2d %%' % (
            tag, dataloader.classes[i], 100 * class_correct[i] / class_total[i]))
Пример #2
0
    def forward(self, vocab):
        """Score a batch of (sentence, aspect) pairs with attention-over-attention.

        Args:
            vocab: mapping with 'sentence' and 'aspect' token-id tensors and
                'sent_len', the unpadded sentence lengths handed to
                ``pack_padded_sequence`` (which by default expects them sorted
                in decreasing order).

        Returns:
            ``self.fc`` applied to the attention-weighted sentence representation.
        """
        # Move inputs to wherever the embedding weights live instead of forcing
        # CUDA, so the module also runs on CPU.
        device = next(self.embedding.parameters()).device

        # Embedding lookups run without autograd: self.embedding stays frozen.
        with torch.no_grad():
            s_embedding = self.embedding(vocab['sentence'].to(device))
            a_embedding = self.embedding(vocab['aspect'].to(device))

        packed_s = pack_padded_sequence(s_embedding, vocab['sent_len'], batch_first=True)

        out_s, _ = self.lstm_s(packed_s)  # packed output
        out_a, _ = self.lstm_a(a_embedding)

        # Unpacking must stay on the autograd tape. Running it under no_grad
        # (as before) detached out_s, so self.lstm_s never received a gradient.
        unpacked_out_s, _ = pad_packed_sequence(out_s, batch_first=True)

        # Pair-wise interaction matrix
        I_matrix = torch.bmm(unpacked_out_s, out_a.permute(0, 2, 1))

        # Column-wise softmax
        a2s_attn = F.softmax(I_matrix, dim=1)

        # Row-wise softmax => Column-wise average => aspect attention
        s2a_attn = F.softmax(I_matrix, dim=2)
        a_attn = torch.mean(s2a_attn, dim=1)

        # Final sentence attn => weighted sum of each individual a2s_attn
        s_attn = torch.bmm(a2s_attn, a_attn.unsqueeze(-1))

        final_rep = torch.bmm(unpacked_out_s.permute(0, 2, 1), s_attn).squeeze(-1)
        pred = self.fc(final_rep)
        return pred
Пример #3
0
 def predict_proba(self, X):
     """Return model outputs for ``X`` without building an autograd graph.

     ``X`` is first moved to ``self.cf_a.device``. For regression, the raw
     ``self.forward(X)`` is returned. For classification, the softmax over
     dim 1 is returned.

     Raises:
         ValueError: if ``self.cf_a.task_type`` is neither "regression" nor
             "classification" (previously the method silently returned None).
     """
     X = X.to(device=self.cf_a.device)
     task_type = self.cf_a.task_type
     if task_type not in ("regression", "classification"):
         raise ValueError("Unknown task_type: %r" % (task_type,))
     with torch.no_grad():
         output = self.forward(X)
         if task_type == "classification":
             output = nn.functional.softmax(output, dim=1)
     return output
Пример #4
0
 def predict(self, X):
     """sklearn-style ``predict`` that does not build an autograd graph.

     ``X`` is first moved to ``self.cf_a.device``. For regression,
     ``self.forward(X)`` is returned unchanged. For classification, the argmax
     over dim 1 (the predicted class index per row) is returned.

     Raises:
         ValueError: if ``self.cf_a.task_type`` is neither "regression" nor
             "classification" (previously the method silently returned None).
     """
     X = X.to(device=self.cf_a.device)
     task_type = self.cf_a.task_type
     if task_type not in ("regression", "classification"):
         raise ValueError("Unknown task_type: %r" % (task_type,))
     with torch.no_grad():
         output = self.forward(X)
         if task_type == "classification":
             output = torch.argmax(output, dim=1)
     return output
Пример #5
0
 def forward(self, encoder_output, hsz, beam_width=1):
     """Prepare the initial decoder state from the encoder output.

     Args:
         encoder_output: object whose ``output`` attribute is the encoder
             context (batch on dim 0); it is also passed to ``self.get_state``.
         hsz: width of the zero "initial output" tensor per batch element.
         beam_width: when > 1, the context (via ``repeat_batch``) and the
             hidden state (along dim 1; both halves if it is a tuple) are tiled
             once per beam. The tiling is done without autograd tracking.

     Returns:
         ``(h_i, init_zeros, context)``. ``init_zeros`` is a ``(batch, hsz)``
         zero tensor with the dtype and device of ``context``.
     """
     h_i = self.get_state(encoder_output)
     context = encoder_output.output
     if beam_width > 1:
         with torch.no_grad():
             context = repeat_batch(context, beam_width)
             if type(h_i) is tuple:
                 h_i = repeat_batch(h_i[0], beam_width, dim=1), repeat_batch(h_i[1], beam_width, dim=1)
             else:
                 h_i = repeat_batch(h_i, beam_width, dim=1)
     batch_size = context.shape[0]
     # new_zeros replaces the deprecated ``.data.new(...).zero_()`` idiom. It
     # copies dtype/device from context and never tracks gradients.
     init_zeros = context.new_zeros((batch_size, hsz))
     return h_i, init_zeros, context
Пример #6
0
def stylize(args):
    """Stylize ``args.content_image`` and save it to ``args.output_image``.

    ``args.model`` is either an ``.onnx`` file, which is run through
    ``stylize_onnx_caffe2``, or a ``TransformerNet`` checkpoint. For a
    checkpoint, ``args.export_onnx`` may name an ``.onnx`` file to export the
    model to; the output then comes from the export call.
    """
    device = torch.device("cuda" if args.cuda else "cpu")

    content_image = utils.load_image(args.content_image, scale=args.content_scale)
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    content_image = content_transform(content_image)
    content_image = content_image.unsqueeze(0).to(device)

    if args.model.endswith(".onnx"):
        output = stylize_onnx_caffe2(content_image, args)
    else:
        with torch.no_grad():
            style_model = TransformerNet()
            # map_location lets a checkpoint saved on GPU load on a CPU run.
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(args.model, map_location=device)
            # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
            for k in list(state_dict.keys()):
                if re.search(r'in\d+\.running_(mean|var)$', k):
                    del state_dict[k]
            style_model.load_state_dict(state_dict)
            style_model.to(device)
            if args.export_onnx:
                assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
                output = torch.onnx._export(style_model, content_image, args.export_onnx).cpu()
            else:
                output = style_model(content_image).cpu()
    utils.save_image(args.output_image, output[0])
Пример #7
0
def generate_translation(encoder, decoder, sentence, max_length, search="greedy", k=None):
    """Translate ``sentence`` with a token-by-token encoder and a decoder search.

    @param max_length: the max # of words that the decoder can return
    @param search: "greedy" or "beam"; any other value returns an empty list
    @param k: beam width for beam search (2 when None)
    @returns decoded_words: a list of words in target language
    """
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        input_length = input_tensor.size()[0]

        # encode the source sentence one token at a time; only the final
        # hidden state is kept
        encoder_hidden = encoder.initHidden()
        for ei in range(input_length):
            _, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)

        # start decoding from SOS with the final encoder state
        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_hidden = encoder_hidden
        decoded_words = []

        if search == 'greedy':
            decoded_words = greedy_search(decoder, decoder_input, decoder_hidden, max_length)
        elif search == 'beam':
            if k is None:
                k = 2
            decoded_words = beam_search(decoder, decoder_input, decoder_hidden, max_length, k)

        return decoded_words
Пример #8
0
    def test(self, evaluation=True):

        self.model.eval()
        loader = self.data_loader['test']
        loss_value = []
        result_frag = []
        label_frag = []

        for data, label in loader:
            
            # get data
            data = data.float().to(self.dev)
            label = label.long().to(self.dev)

            # inference
            with torch.no_grad():
                output = self.model(data)
            result_frag.append(output.data.cpu().numpy())

            # get loss
            if evaluation:
                loss = self.loss(output, label)
                loss_value.append(loss.item())
                label_frag.append(label.data.cpu().numpy())

        self.result = np.concatenate(result_frag)
        if evaluation:
            self.label = np.concatenate(label_frag)
            self.epoch_info['mean_loss']= np.mean(loss_value)
            self.show_epoch_info()

            # show top-k accuracy
            for k in self.arg.show_topk:
                self.show_topk(k)
Пример #9
0
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print classification metrics.

    Each output is averaged over every dimension after the second
    (``(N, C, ...)`` -> ``(N, C)``) before the loss and predictions are
    computed. The function prints:

    - per-batch confidences, predictions and ground truth;
    - the mean cross-entropy loss per sample;
    - the overall accuracy;
    - a scikit-learn classification report and confusion matrix.
    """
    model.to(device)
    model.eval()
    test_loss = 0.0
    correct = 0
    y_pred = []
    y_true = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            output = torch.mean(output.view(output.size(0), output.size(1), -1), dim=2)
            # Sum (not mean) per batch, so dividing by the dataset size below
            # yields a true per-sample average. The old code summed batch means
            # and never divided, so the printed "Average loss" grew with the
            # number of batches.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            output = F.softmax(output, dim=1)
            confidence, pred = output.max(1)
            print('confidence: {}, prediction: {}, ground truth: {}'.format(confidence.cpu().numpy(), pred.cpu().numpy(), target.cpu().numpy()))
            y_pred += pred.tolist()
            y_true += target.tolist()
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))
    print(metrics.classification_report(np.asarray(y_true), np.asarray(y_pred)))
    print('confusion matrix: \n', metrics.confusion_matrix(np.asarray(y_true), np.asarray(y_pred)))
    print('\n')
Пример #10
0
def perform_val(multi_gpu, device, embedding_size, batch_size, backbone, carray, issame, nrof_folds=10, tta=True):
    """Embed ``carray`` with ``backbone`` and compute verification metrics.

    Args:
        multi_gpu: if true, ``backbone`` is a DataParallel-style wrapper whose
            ``.module`` is the model to evaluate.
        device: device the model and each batch are moved to.
        embedding_size: width of one embedding row.
        batch_size: number of images embedded per forward pass.
        backbone: the embedding network.
        carray: 4-D array of images; axis 1 is reordered as ``[2, 1, 0]``
            before embedding (presumably BGR -> RGB — TODO confirm).
        issame: pair labels passed through to ``evaluate``.
        nrof_folds: number of folds passed through to ``evaluate``.
        tta: if true, add the embedding of the horizontally flipped batch and
            L2-normalise the sum.

    Returns:
        ``(mean accuracy, mean best threshold, ROC-curve image tensor)``.
    """
    if multi_gpu:
        backbone = backbone.module  # unpackage model from DataParallel
    backbone = backbone.to(device)
    backbone.eval()  # switch to evaluation mode

    def _embed(chunk):
        # Embed one slice of carray, reordering its channels exactly as every
        # other slice is reordered.
        batch = torch.tensor(chunk[:, [2, 1, 0], :, :])
        if tta:
            fliped = hflip_batch(batch)
            emb_batch = backbone(batch.to(device)).cpu() + backbone(fliped.to(device)).cpu()
            return l2_norm(emb_batch)
        return backbone(batch.to(device)).cpu()

    embeddings = np.zeros([len(carray), embedding_size])
    with torch.no_grad():
        # Full batches, then one final partial batch if any. The previous code
        # skipped the channel reordering for that trailing partial batch, so
        # its embeddings were computed on differently ordered channels.
        for idx in range(0, len(carray), batch_size):
            embeddings[idx:idx + batch_size] = _embed(carray[idx:idx + batch_size])

    tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
    buf = gen_plot(fpr, tpr)
    roc_curve = Image.open(buf)
    roc_curve_tensor = transforms.ToTensor()(roc_curve)

    return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor
def generate_translation(encoder, decoder, sentence, max_length, target_lang, search="greedy", k=None):
    """Translate an already-tensorized sentence with a batched encoder.

    @param sentence: tensor of source token ids; its size along dim 1 is used
        as the source length (presumably shape (1, seq_len) — TODO confirm)
    @param max_length: the max # of words that the decoder can return
    @param target_lang: passed through to beam_search
    @param search: "greedy" or "beam"; any other value returns an empty list
    @param k: beam width for beam search (2 when None)
    @returns decoded_words: a list of words in target language
    """
    with torch.no_grad():
        input_tensor = sentence
        input_length = sentence.size()[1]

        # encode the source sentence as a batch of one, passing its length.
        # (A former ``encoder.init_hidden(1)`` call was removed: its result was
        # overwritten here without being used.)
        encoder_outputs, encoder_hidden = encoder(input_tensor.view(1, -1), torch.tensor([input_length]))

        # start decoding from SOS with the final encoder state
        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_hidden = encoder_hidden
        decoded_words = []

        if search == 'greedy':
            decoded_words = greedy_search_batch(decoder, decoder_input, encoder_outputs, decoder_hidden, max_length)
        elif search == 'beam':
            if k is None:
                k = 2  # default beam width
            decoded_words = beam_search(decoder, decoder_input, encoder_outputs, decoder_hidden, max_length, k, target_lang)

        return decoded_words
Пример #12
0
def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in "Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification" - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Here ``gain = calculate_gain(nonlinearity, a)`` and ``fan_mode`` is the
    fan-in or fan-out of ``tensor``, as selected by ``mode``. For the default
    ``'leaky_relu'`` this is
    :math:`\text{std} = \sqrt{2 / ((1 + a^2) \times \text{fan\_mode})}`.

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either 'fan_in' (default) or 'fan_out'. Choosing `fan_in`
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing `fan_out` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with 'relu' or 'leaky_relu' (default).

    Returns:
        ``tensor`` itself, filled in place (without autograd tracking).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)
Пример #13
0
def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (``std`` defaults to 0.01), as
    described in "Deep learning via Hessian-free optimization" -
    Martens, J. (2010).

    In every column, exactly ``ceil(sparsity * rows)`` randomly chosen entries
    are set to zero.

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Returns:
        ``tensor`` itself, filled in place (without autograd tracking).

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            # a fresh random permutation per column picks which rows to zero
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
Пример #14
0
def dirac_(tensor):
    r"""Fill a 3-, 4- or 5-D convolution weight in place with a Dirac delta.

    Every entry is zeroed. Then, for each channel index ``c`` below
    ``min(out_channels, in_channels)``, the centre tap of ``tensor[c, c]`` is
    set to 1. A convolution with these weights therefore passes through as
    many input channels as possible unchanged.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`

    Returns:
        ``tensor`` itself.

    Raises:
        ValueError: for any other number of dimensions.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
    """
    ndim = tensor.ndimension()
    if ndim not in (3, 4, 5):
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    shape = tensor.size()
    channels = min(shape[0], shape[1])
    # Centre index along each kernel axis. This covers temporal (3-D),
    # spatial (4-D) and volumetric (5-D) kernels with a single expression.
    centre = tuple(extent // 2 for extent in shape[2:])
    with torch.no_grad():
        tensor.zero_()
        for channel in range(channels):
            tensor[(channel, channel) + centre] = 1
    return tensor
Пример #15
0
def xavier_uniform_(tensor, gain=1):
    r"""Fill ``tensor`` in place using Glorot (Xavier) uniform initialisation.

    Values are sampled from :math:`\mathcal{U}(-a, a)` with

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    following "Understanding the difficulty of training deep feedforward
    neural networks" - Glorot, X. & Bengio, Y. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Returns:
        ``tensor`` itself, filled without autograd tracking.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    # A uniform U(-b, b) has standard deviation b / sqrt(3), so b = sqrt(3) * std.
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        filled = tensor.uniform_(-bound, bound)
    return filled
Пример #16
0
def evaluate(model: Model,
             instances: Iterable[Instance],
             data_iterator: DataIterator,
             cuda_device: int) -> Dict[str, Any]:
    """Run ``model`` once over ``instances`` without gradients; return its metrics.

    A single non-shuffled epoch is drawn from ``data_iterator``. After each
    batch, the running metrics are shown on the tqdm bar, except those whose
    names start with "_"; a one-time warning is logged if any such metric
    exists. Returns ``model.get_metrics(reset=True)``.
    """
    check_for_gpu(cuda_device)
    warned_about_underscores = False
    with torch.no_grad():
        model.eval()

        batches = data_iterator(instances, num_epochs=1, shuffle=False)
        logger.info("Iterating over dataset")
        progress = Tqdm.tqdm(batches, total=data_iterator.get_num_batches(instances))
        for batch in progress:
            model(**util.move_to_device(batch, cuda_device))
            metrics = model.get_metrics()
            has_hidden = any(name.startswith("_") for name in metrics)
            if has_hidden and not warned_about_underscores:
                logger.warning("Metrics with names beginning with \"_\" will "
                               "not be logged to the tqdm progress bar.")
                warned_about_underscores = True
            shown = ["%s: %.2f" % (name, value)
                     for name, value in metrics.items() if not name.startswith("_")]
            progress.set_description(', '.join(shown) + " ||", refresh=False)

        return model.get_metrics(reset=True)
Пример #17
0
def kaiming_normal(tensor, a=0, mode='fan_in'):
    r"""Fills the input Tensor or Variable with values according to the method
    described in "Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification" - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`N(0, std^2)` where
    :math:`std = \sqrt{2 / ((1 + a^2) \times fan)}` and ``fan`` is the fan-in
    or fan-out of ``tensor``, as selected by ``mode``. Also known as He
    initialisation.

    (The docstring is a raw string so the LaTeX backslashes are not parsed as
    invalid Python escape sequences.)

    Args:
        tensor: an n-dimensional torch.Tensor or autograd.Variable
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either 'fan_in' (default) or 'fan_out'. Choosing `fan_in`
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing `fan_out` preserves the magnitudes in the
            backwards pass.

    Returns:
        ``tensor`` itself, filled in place (without autograd tracking).

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> nn.init.kaiming_normal(w, mode='fan_out')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain('leaky_relu', a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)
def resnet_features(batch_arrayd):
    """Extract ResNet-50 features for the first video in ``batch_arrayd``.

    Args:
        batch_arrayd: dict mapping a video id to an array of frames, laid out
            as (frames, height, width, channels). The permute below assumes
            this layout — presumably (N, 224, 224, 3); TODO confirm.

    Returns:
        ``{first_id: features}``, where ``features`` is a CPU float tensor of
        shape ``(num_frames, 2048)``. Only the first entry of the dict is
        processed; any further entries are ignored.

    Raises:
        IndexError: if ``batch_arrayd`` is empty.
    """
    chunk_size = 100  # frames sent through the network per forward pass
    with torch.no_grad():
        first_id = list(batch_arrayd)[0]
        # change datatype of frames to float32
        frames = np.array(batch_arrayd[first_id], dtype=np.float32)
        video_tensor = torch.from_numpy(frames)

        num_frames = video_tensor.size(0)
        resnet_feature = torch.zeros(num_frames, 2048)

        video_tensor = video_tensor.permute(0, 3, 1, 2)  # -> [?, 3, 224, 224]

        # Under no_grad the deprecated Variable wrapper and .data are unnecessary.
        for start in range(0, num_frames, chunk_size):
            end = min(start + chunk_size, num_frames)
            feature = resnet50(video_tensor[start:end].to(device))
            # drop the two trailing singleton spatial dims -> (n, 2048)
            resnet_feature[start:end] = feature.squeeze(3).squeeze(2)

        return {first_id: resnet_feature}
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """Greedily translate ``sentence``; return the words and attention weights.

    The source is encoded one token at a time into a zero-initialised
    ``(max_length, encoder.hidden_size)`` buffer. Decoding starts from SOS and
    feeds back the top-1 token. It stops after appending '<EOS>' or after
    ``max_length`` steps.

    Returns:
        ``(decoded_words, attentions)``. ``attentions`` keeps one row per
        decoding step that was actually executed.
    """
    with torch.no_grad():
        source = tensorFromSentence(input_lang, sentence)
        source_len = source.size()[0]
        hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        for pos in range(source_len):
            step_output, hidden = encoder(source[pos], hidden)
            encoder_outputs[pos] += step_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS
        decoder_hidden = hidden  # the decoder starts from the final encoder state

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_attentions[di] = decoder_attention.data
            _, topi = decoder_output.data.topk(1)
            token = topi.item()
            if token == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[token])
            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]
    def segment(self, image):
        """Predict a palettised segmentation mask for a single image.

        The image is converted to RGB, turned into a tensor and normalised with
        the dataset's stored mean/std. It is then run through ``self.net`` on
        ``self.device``. The per-pixel argmax along axis 0 of the squeezed
        output is returned as a mode 'P' image, coloured with the dataset's
        palette.
        """
        # Prediction only: keep autograd from recording anything.
        with torch.no_grad():
            stats = self.dataset['stats']
            transform = Compose([
                ConvertImageMode(mode='RGB'),
                ImageToTensor(),
                Normalize(mean=stats['mean'], std=stats['std'])
            ])
            batch = transform(image).unsqueeze(0).to(self.device)

            scores = self.net(batch).cpu().data.numpy().squeeze(0)
            class_map = scores.argmax(axis=0).astype(np.uint8)

            mask = Image.fromarray(class_map, mode='P')
            mask.putpalette(make_palette(*self.dataset['common']['colors']))
            return mask
Пример #21
0
 def sample(self, sample_shape=torch.Size()):
     """Draw a sample (or a batch of samples) detached from autograd.

     Delegates to ``self.rsample(sample_shape)`` inside ``torch.no_grad()``, so
     the result has no gradient history. It is a sample_shape shaped sample,
     or a sample_shape shaped batch of samples if the distribution parameters
     are batched.
     """
     with torch.no_grad():
         draws = self.rsample(sample_shape)
     return draws
Пример #22
0
 def train_actor(self, batch):
     '''Trains the actor when the actor and critic are separate networks'''
     # Advantages are computed without autograd, so the policy loss treats
     # them as constants. The value targets are not needed here.
     with torch.no_grad():
         advantages, _ = self.calc_advs_v_targets(batch)
     loss = self.calc_policy_loss(batch, advantages)
     self.net.training_step(loss=loss)
     return loss
Пример #23
0
    def train_batch(self,  # type: ignore
                    question: Dict[str, torch.LongTensor],
                    passage: Dict[str, torch.LongTensor],
                    span_start: torch.IntTensor = None,
                    span_end: torch.IntTensor = None,
                    metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """Run one optimisation step on the data loss combined with the KL term.

        It is enough to just compute the total loss because the normal weights
        do not depend on the KL Divergence.

        When ``self._optimizer`` is None, a plain SGD update
        ``p -= self.lr * p.grad`` is applied to every trainable parameter that
        received a gradient. Otherwise ``self._optimizer.step()`` is called,
        followed by its ``zero_grad()``.

        Returns:
            The output dict of ``self.forward`` (which includes "loss").
        """
        # Computing both losses builds the dynamic graph.
        output = self.forward(question, passage, span_start, span_end, metadata)
        data_loss = output["loss"]

        KL_div = self.get_KL_divergence()
        total_loss = self.combine_losses(data_loss, KL_div)

        self.zero_grad()  # zeroes the gradient buffers of all parameters
        total_loss.backward()

        if self._optimizer is None:
            with torch.no_grad():
                for param in self.parameters():
                    # A trainable parameter untouched by this loss has no gradient;
                    # the old code crashed on ``None.data`` in that case.
                    if param.requires_grad and param.grad is not None:
                        param.sub_(param.grad * self.lr)
        else:
            self._optimizer.step()
            self._optimizer.zero_grad()

        return output
Пример #24
0
def train(net):
    """Train ``net`` for 1000 epochs on ``VOCDetection()`` batches of two.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``. Prior
    boxes are generated once, without autograd. For every epoch after 500,
    ``adjust_learning_rate(optimizer, 1e-4)`` is called. After each epoch the
    mean of each of the criterion's two loss terms and the elapsed time are
    printed, and the weights are saved to 'Final_FaceBoxes.pth'.
    """
    net.train()
    prior_generator = PriorBox()
    with torch.no_grad():
        priors = prior_generator.forward().to(device)

    loader = DataLoader(VOCDetection(), batch_size=2, collate_fn=detection_collate, num_workers=12)

    for epoch in range(1000):
        conf_losses, loc_losses = [], []
        started = time.time()
        if epoch > 500:
            adjust_learning_rate(optimizer, 1e-4)

        for images, targets in loader:
            images = images.to(device)
            targets = [anno.to(device) for anno in targets]
            out = net(images)
            optimizer.zero_grad()
            loss_l, loss_c = criterion(out, priors, targets)

            # the first criterion term is weighted twice as heavily
            (2 * loss_l + loss_c).backward()
            optimizer.step()
            conf_losses.append(loss_c.item())
            loc_losses.append(loss_l.item())
        finished = time.time()

        print(f'{np.mean(conf_losses)}, {np.mean(loc_losses)} time:{finished-started}')
        torch.save(net.state_dict(), 'Final_FaceBoxes.pth')
Пример #25
0
def tts(model, text, p=0, speaker_id=None, fast=True):
    """Convert text to a speech waveform with a deepvoice3 model.

    Args:
        model: the deepvoice3 model; it is moved to ``device`` and put in eval mode.
        text (str): input text to be synthesized.
        p (float): replace words with their pronunciation if p > 0. Default is 0.
        speaker_id: optional speaker index, passed to the model as a one-element
            LongTensor; None passes ``speaker_ids=None``.
        fast (bool): call ``model.make_generation_fast_()`` before decoding.

    Returns:
        ``(waveform, alignment, spectrogram, mel)`` for the first batch item.
        ``spectrogram`` and ``mel`` are denormalised. The waveform is produced
        by ``audio.inv_spectrogram`` from the (transposed) linear output.
    """
    model = model.to(device)
    model.eval()
    if fast:
        model.make_generation_fast_()

    token_ids = np.array(_frontend.text_to_sequence(text, p=p))
    sequence = torch.from_numpy(token_ids).unsqueeze(0).long().to(device)
    # 1-based position of every input token
    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long().to(device)
    if speaker_id is None:
        speaker_ids = None
    else:
        speaker_ids = torch.LongTensor([speaker_id]).to(device)

    # Greedy decoding
    with torch.no_grad():
        mel_outputs, linear_outputs, alignments, _ = model(
            sequence, text_positions=text_positions, speaker_ids=speaker_ids)

    def first_as_numpy(batch):
        return batch[0].cpu().data.numpy()

    linear_output = first_as_numpy(linear_outputs)
    spectrogram = audio._denormalize(linear_output)
    alignment = first_as_numpy(alignments)
    mel = audio._denormalize(first_as_numpy(mel_outputs))

    # Predicted audio signal
    waveform = audio.inv_spectrogram(linear_output.T)

    return waveform, alignment, spectrogram, mel
def fit_norm_distribution_param(args, model, train_dataset, channel_idx=0):
    """Fit a Gaussian (mean, covariance) to multi-step prediction errors.

    For every time step t, the model is run on ``train_dataset[t]`` from the
    carried hidden state. It is then fed its own output for
    ``args.prediction_window_size - 1`` further steps, and the ``channel_idx``
    entry of each output is recorded in ``predictions[t]``. For
    t >= prediction_window_size, one prediction is gathered from each of the
    previous window origins (presumably the ones targeting step t — TODO
    confirm against the model's one-step-ahead convention). They are
    collected in ``organized[t]``, and ``train_dataset[t][0][channel_idx]`` is
    subtracted from them to form ``errors[t]``.

    Args:
        args: needs ``prediction_window_size`` and ``device``.
        model: needs ``init_hidden``, ``forward`` and ``repackage_hidden``.
        train_dataset: indexable sequence of tensors, one per time step.
        channel_idx: which output/input channel to model.

    Returns:
        (mean, cov): the mean error vector (length prediction_window_size) and
        the error covariance, computed as E[x x^T] - mean mean^T
        (i.e. normalised by the number of samples).
    """
    predictions = []
    organized = []
    errors = []
    with torch.no_grad():
        # Turn on evaluation mode which disables dropout.
        model.eval()
        pasthidden = model.init_hidden(1)
        for t in range(len(train_dataset)):
            out, hidden = model.forward(train_dataset[t].unsqueeze(0), pasthidden)
            predictions.append([])
            organized.append([])
            errors.append([])
            predictions[t].append(out.data.cpu()[0][0][channel_idx])
            # Carry the state after the real input forward to the next t; the
            # free-running steps below start from `hidden` and do not affect it.
            pasthidden = model.repackage_hidden(hidden)
            # Free-running: feed the model its own output to predict further ahead.
            for prediction_step in range(1,args.prediction_window_size):
                out, hidden = model.forward(out, hidden)
                predictions[t].append(out.data.cpu()[0][0][channel_idx])

            if t >= args.prediction_window_size:
                # Origin (t - W + step) contributes its horizon-(W - 1 - step)
                # prediction, W = prediction_window_size.
                for step in range(args.prediction_window_size):
                    organized[t].append(predictions[step+t-args.prediction_window_size][args.prediction_window_size-1-step])
                organized[t]= torch.FloatTensor(organized[t]).to(args.device)
                errors[t] = organized[t] - train_dataset[t][0][channel_idx]
                errors[t] = errors[t].unsqueeze(0)

    # Entries before prediction_window_size are still empty lists; skip them.
    errors_tensor = torch.cat(errors[args.prediction_window_size:],dim=0)
    mean = errors_tensor.mean(dim=0)
    cov = errors_tensor.t().mm(errors_tensor)/errors_tensor.size(0) - mean.unsqueeze(1).mm(mean.unsqueeze(0))
    # cov: positive-semidefinite and symmetric.

    return mean, cov
def stylize(args):
    """Apply a pretrained style to every file in ``args.content_dir``.

    Loads ``<args.model_dir>/<args.style>.pth`` into a ``TransformerNet``.
    Each entry of ``args.content_dir`` is then stylized and saved to
    ``args.output_dir`` under the same file name.
    """
    device = torch.device("cuda" if args.cuda else "cpu")
    # The same transform serves every image, so build it once, not per file.
    content_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    with torch.no_grad():
        style_model = TransformerNet()
        # map_location lets a checkpoint saved on GPU load on a CPU-only run.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(os.path.join(args.model_dir, args.style+".pth"), map_location=device)
        # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
        for k in list(state_dict.keys()):
            if re.search(r'in\d+\.running_(mean|var)$', k):
                del state_dict[k]
        style_model.load_state_dict(state_dict)
        style_model.to(device)

        for filename in os.listdir(args.content_dir):
            print("Processing {}".format(filename))
            full_path = os.path.join(args.content_dir, filename)
            content_image = load_image(full_path, scale=args.content_scale)
            content_image = content_transform(content_image)
            content_image = content_image.unsqueeze(0).to(device)

            output = style_model(content_image).cpu()

            output_path = os.path.join(args.output_dir, filename)
            save_image(output_path, output[0])
Пример #28
0
def predict(image_path, model, topk, architecture):
    """Return the ``topk`` most likely classes for the image at ``image_path``.

    Args:
        image_path: path handed to ``load_image``.
        model: classifier exposing ``class_to_idx``. Its output is
            exponentiated, i.e. assumed to be log-probabilities (e.g. from
            LogSoftmax) — TODO confirm.
        topk: number of classes to return.
        architecture: device to run on (cuda or cpu).

    Returns:
        ``(top_probs, top_labels, top_flowers)`` for the first batch item:
        the probabilities, the class labels (keys of ``model.class_to_idx``)
        and their names looked up in ``cat_to_name``.
    """
    img = load_image(image_path)

    model.eval()
    # set architecture (cuda or cpu)
    model.to(architecture)
    img = img.to(architecture)

    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks run.
        output = model(img)

    # get probabilities
    probability = torch.exp(output)

    # get top k probabilities and their class indices
    top_probs, top_labs = probability.topk(topk)

    # convert the first batch row to plain Python lists
    top_probs = top_probs.cpu().numpy()[0].tolist()
    top_labs = top_labs.cpu().numpy()[0].tolist()

    # reverse class_to_idx
    idx_to_class = {idx: cls for cls, idx in model.class_to_idx.items()}

    # map to classes from file and to string labels
    top_labels = [idx_to_class[label] for label in top_labs]
    top_flowers = [cat_to_name[label] for label in top_labels]

    return top_probs, top_labels, top_flowers
Пример #29
0
 def forward(self, images):
     """Extract feature vectors from input images.

     The backbone ``self.resnet`` runs without autograd, so it receives no
     gradient through this path. Its output is flattened per sample, then
     passed through ``self.linear`` and ``self.bn``.
     """
     with torch.no_grad():
         backbone_out = self.resnet(images)
     flat = backbone_out.reshape(backbone_out.size(0), -1)
     projected = self.linear(flat)
     return self.bn(projected)
Пример #30
0
def main():
    """Run XLSor inference over the test list and save one PNG mask per image.

    Settings come from ``get_arguments()``. Weights are restored from
    ``args.restore_from``, and each prediction is written to
    ``outputs/<image-stem>_xlsor.png``.
    """
    args = get_arguments()

    # Restrict which GPUs this process can see.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    model = XLSor(num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)

    model.eval()
    model.cuda()

    testloader = data.DataLoader(
        XRAYDataTestSet(args.data_dir, args.data_list, crop_size=(512, 512),
                        mean=IMG_MEAN, scale=False, mirror=False),
        batch_size=1, shuffle=False, pin_memory=True)

    interp = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=True)

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processed' % (index))
        image, _size, name = batch
        with torch.no_grad():
            prediction = model(image.cuda(), args.recurrence)
            # Some model configurations return a list; only the first element is used.
            if isinstance(prediction, list):
                prediction = prediction[0]
            # Upsample to 512x512, take the single batch item, and move channels last.
            prediction = interp(prediction).cpu()[0].numpy().transpose(1, 2, 0)
        # Channel 0, clipped to [0, 1], is saved as an 8-bit grayscale image.
        mask = (np.clip(prediction[:, :, 0], 0, 1) * 255).astype(np.uint8)
        # Derive the output name from the stem so that any input extension
        # (not only ".png") yields "<stem>_xlsor.png".
        stem = os.path.splitext(os.path.basename(name[0]))[0]
        PILImage.fromarray(mask).save(
            os.path.join(output_dir, stem + '_xlsor.png'), 'png')
    def test_pt_tf_model_equivalence(self):
        """Check that the TF 2 and PyTorch versions of each model agree numerically.

        For every TF model class in ``self.all_model_classes``, the PyTorch
        class with the same name minus the leading ``TF`` is built from the same
        config. Weights are converted PT -> TF -> PT, first in memory and then
        through checkpoint files. After each round trip, the first outputs of
        the two models must agree within ``2e-2``. Positions that are NaN in
        either output are ignored. Returns early without checking anything when
        PyTorch is not available.
        """
        if not is_torch_available():
            return

        import torch
        import transformers

        tolerance = 2e-2

        def to_pt_inputs(tf_inputs):
            # Convert the test inputs (objects exposing .numpy()) to int64 torch tensors.
            return {
                name: torch.from_numpy(value.numpy()).to(torch.long)
                for name, value in tf_inputs.items()
            }

        def nan_masked_max_abs_diff(tf_out, pt_out):
            # Build the NaN mask from BOTH arrays before zeroing anything, then
            # zero those positions in both. This way a NaN on one side is never
            # compared against a real value on the other side.
            tf_out = np.array(tf_out)
            pt_out = np.array(pt_out)
            nan_mask = np.isnan(tf_out) | np.isnan(pt_out)
            tf_out[nan_mask] = 0
            pt_out[nan_mask] = 0
            return np.amax(np.abs(tf_out - pt_out))

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)

            config.output_hidden_states = True
            tf_model = model_class(config)
            pt_model = pt_model_class(config)

            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()
            pt_inputs_dict = to_pt_inputs(inputs_dict)
            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(inputs_dict, training=False)
            max_diff = nan_masked_max_abs_diff(tfo[0].numpy(), pto[0].numpy())
            # Debug info (remove when fixed)
            if max_diff >= tolerance:
                print("===")
                print(model_class)
                print(config)
                print(inputs_dict)
                print(pt_inputs_dict)
            self.assertLessEqual(max_diff, tolerance)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

            # Same comparison after the checkpoint round trip; both models in inference mode.
            pt_model.eval()
            with torch.no_grad():
                pto = pt_model(**pt_inputs_dict)
            tfo = tf_model(inputs_dict, training=False)
            max_diff = nan_masked_max_abs_diff(tfo[0].numpy(), pto[0].numpy())
            self.assertLessEqual(max_diff, tolerance)
def _log_and_write_metrics(path, header, metrics):
    """Log every ``key = value`` pair of ``metrics`` and write it to ``path``."""
    with open(path, "w") as writer:
        logger.info(header)
        for key, value in metrics.items():
            logger.info("  %s = %s", key, str(value))
            writer.write("%s = %s\n" % (key, str(value)))


def evaluate(args, model, tokenizer, mode, prefix=""):
    """Evaluate a multi-label classifier on the ``mode`` split.

    The model is run over the cached examples for ``mode``, with a sigmoid
    applied to its logits. Predictions are thresholded at 0.5 and scored with
    ``eval_hoc``; ``eval_roc_auc`` is computed on the sigmoid scores. Metrics
    are logged and written to
    ``<output_dir>/<result_prefix><mode, or 'eval' for 'dev'>_results.txt``.
    When ``args.output_all_logits`` is set, the per-sample scores and the
    per-class ROC-AUC values are written too.

    Args:
        args: Run configuration. Note that ``args.eval_batch_size`` is set here.
        model: Model to evaluate; wrapped in ``DataParallel`` for multi-GPU runs.
        tokenizer: Forwarded to ``load_and_cache_examples``.
        mode: Split name; also selects ``<data_dir>/<mode>.tsv``.
        prefix: Label used only in log messages.

    Returns:
        ``(result, preds, df)``: the metric dict (including ``micro_roc_auc``
        and ``loss``), the sigmoid scores with one row per example, and a
        DataFrame with the ``index``, ``labels`` and ``pred_labels`` columns.

    Raises:
        ValueError: If the evaluation dataloader yields no batches.
    """
    eval_task = args.task_name
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, mode)

    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Evaluate in order; note that a DistributedSampler would sample randomly.
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # multi-gpu eval
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    # Switch to inference mode once rather than on every batch.
    model.eval()
    eval_loss = 0.0
    nb_eval_steps = 0
    batch_preds = []
    batch_label_ids = []
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            tmp_eval_loss, logits = outputs[:2]

            logits = logits.sigmoid()  # for multi label classification
            # .mean() averages the per-replica losses when wrapped in DataParallel.
            eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        batch_preds.append(logits.detach().cpu().numpy())
        batch_label_ids.append(inputs["labels"].detach().cpu().numpy())

    if nb_eval_steps == 0:
        raise ValueError(
            "Evaluation dataset for mode {!r} produced no batches".format(mode))

    # Concatenate once; calling np.append per batch re-copies everything (quadratic).
    preds = np.concatenate(batch_preds, axis=0)
    out_label_ids = np.concatenate(batch_label_ids, axis=0)

    eval_loss = eval_loss / nb_eval_steps

    # create DataFrame to compare y_pred with y_true.
    df = pd.read_csv(os.path.join(args.data_dir, "{}.tsv".format(mode)),
                     sep='\t')

    # compute F1-score for HoC
    threshold = 0.5
    y_pred = (preds > threshold).astype(int)
    # Encode each row as "0_v0,1_v1,..." (label position, predicted 0/1).
    df['pred_labels'] = [
        ','.join('{}_{}'.format(i, v) for i, v in enumerate(row))
        for row in y_pred
    ]

    result = eval_hoc(df, mode)

    # compute ROC-AUC for multi label classification
    roc_auc = eval_roc_auc(out_label_ids, preds, args.num_labels)
    result['micro_roc_auc'] = roc_auc['micro']

    result['loss'] = eval_loss

    output_eval_file = os.path.join(
        args.output_dir, args.result_prefix +
        "{}_results.txt".format(mode if mode != 'dev' else 'eval'))
    _log_and_write_metrics(output_eval_file,
                           "***** Eval results {} *****".format(prefix),
                           result)

    if args.output_all_logits:
        # Note: ``preds`` holds sigmoid scores, despite the "logits" file name.
        output_all_logit_file = os.path.join(
            args.output_dir,
            args.result_prefix + "{}_all_logits.txt".format(mode))
        with open(output_all_logit_file, "w") as writer:
            logger.info("***** Output all logits {} *****".format(prefix))
            for sample in preds:
                writer.write('\t'.join(['{:.3f}'.format(v)
                                        for v in sample]) + '\n')

        # output ROC curve and ROC area for each class
        output_roc_auc_file = os.path.join(
            args.output_dir,
            args.result_prefix + "{}_roc_auc_for_each_class.txt".format(mode))
        _log_and_write_metrics(
            output_roc_auc_file,
            "***** output ROC curve and ROC area for each class {} *****".
            format(prefix),
            roc_auc)

    return result, preds, df[['index', 'labels', 'pred_labels']]
Пример #33
0
def multi_gpu_test(model, dataset, cfg, show=False, tmpdir=None):
    """Evaluate ``model`` on this rank's share of ``dataset`` and gather errors.

    Rank ``r`` of ``world_size`` handles indices ``r, r + world_size, ...``.

    For each sample:
      * its non-None fields are collated and scattered to the current CUDA
        device;
      * the model's ``'disps'`` output is cropped with ``remove_padding`` and
        scored by ``do_evaluation``;
      * occlusion metrics are added when ``cfg.model.eval.eval_occlusion`` is
        set and both ground-truth disparities are present.

    Only each sample's ``'Error'`` dict is kept. When ``show`` is true, the
    left/right images are also loaded and the full result is passed to
    ``save_result``. Per-rank lists are merged by ``collect_results``, and its
    return value is returned.
    """
    model.eval()
    rank, world_size = get_dist_info()
    results = []
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))

    for idx in range(rank, len(dataset), world_size):
        data = dataset[idx]

        # scatter() cannot handle None values, so only the populated fields are
        # collated, moved to the current device and merged back into ``data``.
        present = {key: value for key, value in data.items() if value is not None}
        moved = scatter(collate([present], samples_per_gpu=1),
                        [torch.cuda.current_device()])[0]
        data.update(moved)

        # TODO: evaluate after generate all predictions!
        with torch.no_grad():
            model_out, _ = model(data)
            ori_size = data['original_size']
            disps = remove_padding(model_out['disps'], ori_size)

            raw_target = data['leftDisp'] if 'leftDisp' in data else None
            target = remove_padding(raw_target, ori_size)

            lower = cfg.model.eval.lower_bound
            upper = cfg.model.eval.upper_bound
            error_dict = do_evaluation(disps[0], target, lower, upper)

            has_both_gt = 'leftDisp' in data and 'rightDisp' in data
            if cfg.model.eval.eval_occlusion and has_both_gt:
                for gt_key in ('leftDisp', 'rightDisp'):
                    data[gt_key] = remove_padding(data[gt_key], ori_size)
                error_dict.update(do_occlusion_evaluation(
                    disps[0], data['leftDisp'], data['rightDisp'],
                    cfg.model.eval.lower_bound, cfg.model.eval.upper_bound))

            result = dict(Disparity=disps, GroundTruth=target, Error=error_dict)

        filter_result = {'Error': result['Error']}

        if show:
            item = dataset.data_list[idx]
            for result_key, path_key in (('leftImage', 'left_image_path'),
                                         ('rightImage', 'right_image_path')):
                image_path = osp.join(cfg.data.test.data_root, item[path_key])
                result[result_key] = imread(image_path).astype(np.float32)
            save_result(result, cfg.out_dir, item['left_image_path'].split('/')[-1])

        if hasattr(cfg, 'sparsification_plot'):
            filter_result['Error'].update(sparsification_eval(result, cfg))

        results.append(filter_result)

        # Rank 0 advances the bar by world_size per own sample to approximate
        # the progress of all ranks.
        if rank == 0:
            for _ in range(world_size):
                prog_bar.update()

    # collect results from all ranks
    return collect_results(results, len(dataset), tmpdir)
Пример #34
0
    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.

        The starting epoch is restored from the checkpoint. The loop then runs
        epochs ``epoch_counter`` .. ``self._num_epochs - 1``. Each epoch:
          * trains via ``self._train_epoch``;
          * validates, when a validation data loader is configured;
          * logs to tensorboard and dumps per-epoch metrics (master only);
          * steps the learning-rate / momentum schedulers;
          * saves a checkpoint (master only);
          * runs the epoch callbacks.

        The loop stops early when the metric tracker runs out of patience.
        Before returning, the checkpointer's best model state (if any) is loaded
        into ``self.model``.

        Returns the accumulated ``metrics`` dict, which holds:
          * peak memory readings;
          * the training/validation values and timing of the last epoch that
            reached the bookkeeping step;
          * the ``best_epoch`` / ``best_validation_*`` entries.

        Raises ``ConfigurationError`` if restoring the checkpoint raises a
        ``RuntimeError``.
        """
        try:
            # epoch_counter is the first epoch index this run will execute.
            epoch_counter = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError(
                "Could not recover training from the checkpoint.  Did you mean to output to "
                "a different serialization directory or delete the existing serialization "
                "directory?"
            )

        training_util.enable_gradient_clipping(self.model, self._grad_clipping)

        logger.info("Beginning training.")

        val_metrics: Dict[str, float] = {}
        # Stays None when there is no validation loader; the schedulers below then receive None.
        this_epoch_val_metric: float = None
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        # Seed the best_* entries from the metric tracker (which may already hold values from a restored run).
        metrics["best_epoch"] = self._metric_tracker.best_epoch
        for key, value in self._metric_tracker.best_epoch_metrics.items():
            metrics["best_validation_" + key] = value

        # Give callbacks a hook before the first epoch (epoch=-1, empty metrics).
        for callback in self._epoch_callbacks:
            callback(self, metrics={}, epoch=-1, is_master=self._master)

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            # get peak of memory usage
            # (running maximum across epochs of the per-GPU / per-worker memory readings)
            for key, value in train_metrics.items():
                if key.startswith("gpu_") and key.endswith("_memory_MB"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
                elif key.startswith("worker_") and key.endswith("_memory_MB"):
                    metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)

            if self._validation_data_loader is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, val_reg_loss, num_batches = self._validation_loss(epoch)

                    # It is safe again to wait till the validation is done. This is
                    # important to get the metrics right.
                    if self._distributed:
                        dist.barrier()

                    val_metrics = training_util.get_metrics(
                        self.model,
                        val_loss,
                        val_reg_loss,
                        batch_loss=None,
                        batch_reg_loss=None,
                        num_batches=num_batches,
                        reset=True,
                        world_size=self._world_size,
                        cuda_device=self.cuda_device,
                    )

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[self._validation_metric]
                    self._metric_tracker.add_metric(this_epoch_val_metric)

                    # NOTE: this break skips the rest of this epoch's bookkeeping
                    # (metrics dump, scheduler steps, checkpoint, callbacks).
                    if self._metric_tracker.should_stop_early():
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            if self._master:
                self._tensorboard.log_metrics(
                    train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
                )  # +1 because tensorboard doesn't like 0

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            # epochs_trained is incremented at the end of this loop body, so this
            # is the number of epochs completed in this run before the current one.
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if self._metric_tracker.is_best_so_far():
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics["best_epoch"] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

                self._metric_tracker.best_epoch_metrics = val_metrics

            if self._serialization_dir and self._master:
                common_util.dump_metrics(
                    os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
                )

            # The Scheduler API is agnostic to whether your schedule requires a validation metric -
            # if it doesn't, the validation metric passed here is ignored.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step(this_epoch_val_metric)
            if self._momentum_scheduler:
                self._momentum_scheduler.step(this_epoch_val_metric)

            if self._master:
                self._checkpointer.save_checkpoint(
                    epoch, self, is_best_so_far=self._metric_tracker.is_best_so_far()
                )

            # Wait for the master to finish saving the checkpoint
            if self._distributed:
                dist.barrier()

            for callback in self._epoch_callbacks:
                callback(self, metrics=metrics, epoch=epoch, is_master=self._master)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                # remaining ~= elapsed * (epochs in this run / epochs done so far - 1)
                estimated_time_remaining = training_elapsed_time * (
                    (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
                )
                formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s", formatted_time)

            epochs_trained += 1

        # make sure pending events are flushed to disk and files are closed properly
        self._tensorboard.close()

        # Load the best model state before returning
        best_model_state = self._checkpointer.best_model_state()
        if best_model_state:
            self.model.load_state_dict(best_model_state)

        return metrics
Пример #35
0
def train_lm(
        data_dir: str,
        model_dir: str,
        dataset: str,
        baseline: str,
        hyper_params: Dict[str, Any],
        loss_type: str,
        compute_train_batch_size: int,
        predict_batch_size: int,
        gpu_ids: Optional[List[int]],
        logger: Optional[logging.Logger] = None
) -> None:
    """Fine-tune a pre-trained LM baseline on a scruples dataset.

    Fine-tune ``baseline`` on ``dataset``, writing all results and
    artifacts to ``model_dir``. Return the best calibrated xentropy achieved on
    dev after any epoch.

    Parameters
    ----------
    data_dir : str
        The path to the directory containing the dataset.
    model_dir : str
        The path to the directory in which to save results.
    dataset : str
        The dataset to use when fine-tuning ``baseline``. Must be either
        "resource" or "corpus".
    baseline : str
        The pre-trained LM to fine-tune. Should be one of the keys for
        ``scruples.baselines.$dataset.FINE_TUNE_LM_BASELINES`` where
        ``$dataset`` corresponds to the ``dataset`` argument to this
        function.
    hyper_params : Dict[str, Any]
        The dictionary of hyper-parameters for the model.
    loss_type : str
        The type of loss to use. Should be one of ``"xentropy-hard"``,
        ``"xentropy-soft"``, ``"xentropy-full"`` or
        ``"dirichlet-multinomial"``.
    compute_train_batch_size : int
        The largest batch size that will fit on the hardware during
        training. Gradient accumulation will be used to make sure the
        actual size of the batch on the hardware respects this limit.
    predict_batch_size : int
        The number of instances to use in a predicting batch.
    gpu_ids : Optional[List[int]]
        A list of IDs for GPUs to use.
    logger : Optional[logging.Logger], optional (default=None)
        The logger to use when logging messages. If ``None``, then no
        messages will be logged.

    Returns
    -------
    float
        The best calibrated xentropy on dev achieved after any epoch.
    bool
        ``True`` if the training loss diverged, ``False`` otherwise.
    """
    gc.collect()
    # collect any garbage to make sure old torch objects are cleaned up (and
    # their memory is freed from the GPU). Otherwise, old tensors can hang
    # around on the GPU, causing CUDA out-of-memory errors.

    if loss_type not in settings.LOSS_TYPES:
        raise ValueError(
            f'Unrecognized loss type: {loss_type}. Please use one of'
            f' "xentropy-hard", "xentropy-soft", "xentropy-full" or'
            f' "dirichlet-multinomial".')

    # Step 1: Manage and construct paths.

    if logger is not None:
        logger.info('Creating the model directory.')

    checkpoints_dir = os.path.join(model_dir, 'checkpoints')
    tensorboard_dir = os.path.join(model_dir, 'tensorboard')
    os.makedirs(model_dir)
    os.makedirs(checkpoints_dir)
    os.makedirs(tensorboard_dir)

    config_file_path = os.path.join(model_dir, 'config.json')
    log_file_path = os.path.join(model_dir, 'log.txt')
    best_checkpoint_path = os.path.join(
        checkpoints_dir, 'best.checkpoint.pkl')
    last_checkpoint_path = os.path.join(
        checkpoints_dir, 'last.checkpoint.pkl')

    # Step 2: Setup the log file.

    if logger is not None:
        logger.info('Configuring log files.')

    log_file_handler = logging.FileHandler(log_file_path)
    log_file_handler.setLevel(logging.DEBUG)
    log_file_handler.setFormatter(logging.Formatter(settings.LOG_FORMAT))
    logging.root.addHandler(log_file_handler)

    # Step 3: Record the script's arguments.

    if logger is not None:
        logger.info(f'Writing arguments to {config_file_path}.')

    with open(config_file_path, 'w') as config_file:
        json.dump({
            'data_dir': data_dir,
            'model_dir': model_dir,
            'dataset': dataset,
            'baseline': baseline,
            'hyper_params': hyper_params,
            'loss_type': loss_type,
            'compute_train_batch_size': compute_train_batch_size,
            'predict_batch_size': predict_batch_size,
            'gpu_ids': gpu_ids
        }, config_file)

    # Step 4: Configure GPUs.

    if gpu_ids:
        if logger is not None:
            logger.info(
                f'Configuring environment to use {len(gpu_ids)} GPUs:'
                f' {", ".join(str(gpu_id) for gpu_id in gpu_ids)}.')

        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_ids))

        if not torch.cuda.is_available():
            raise EnvironmentError('CUDA must be available to use GPUs.')

        device = torch.device('cuda')
    else:
        if logger is not None:
            logger.info('Configuring environment to use CPU.')

        device = torch.device('cpu')

    # Step 5: Fetch the baseline information and training loop parameters.

    if logger is not None:
        logger.info('Retrieving baseline and related parameters.')

    if dataset == 'resource':
        Model, baseline_config, _, make_transform =\
            resource.FINE_TUNE_LM_BASELINES[baseline]
    elif dataset == 'corpus':
        Model, baseline_config, _, make_transform =\
            corpus.FINE_TUNE_LM_BASELINES[baseline]
    else:
        raise ValueError(
            f'dataset must be either "resource" or "corpus", not'
            f' {dataset}.')

    n_epochs = hyper_params['n_epochs']
    train_batch_size = hyper_params['train_batch_size']
    n_gradient_accumulation = math.ceil(
        train_batch_size / (compute_train_batch_size * len(gpu_ids)))

    # Step 6: Load the dataset.

    if logger is not None:
        logger.info(f'Loading the dataset from {data_dir}.')

    featurize = make_transform(**baseline_config['transform'])
    if dataset == 'resource':
        Dataset = ScruplesResourceDataset
        labelize = None
        labelize_scores = lambda scores: np.array(scores).astype(float)
    elif dataset == 'corpus':
        Dataset = ScruplesCorpusDataset
        labelize = lambda s: getattr(Label, s).index
        labelize_scores = lambda scores: np.array([
            score
            for _, score in sorted(
                    scores.items(),
                    key=lambda t: labelize(t[0]))
        ]).astype(float)
    else:
        raise ValueError(
            f'dataset must be either "resource" or "corpus", not'
            f' {dataset}.')

    train = Dataset(
        data_dir=data_dir,
        split='train',
        transform=featurize,
        label_transform=labelize,
        label_scores_transform=labelize_scores)
    dev = Dataset(
        data_dir=data_dir,
        split='dev',
        transform=featurize,
        label_transform=labelize,
        label_scores_transform=labelize_scores)

    train_loader = DataLoader(
        dataset=train,
        batch_size=train_batch_size // n_gradient_accumulation,
        shuffle=True,
        num_workers=len(gpu_ids),
        pin_memory=bool(gpu_ids))
    dev_loader = DataLoader(
        dataset=dev,
        batch_size=predict_batch_size,
        shuffle=False,
        num_workers=len(gpu_ids),
        pin_memory=bool(gpu_ids))

    # Step 7: Create the model, optimizer, and loss.

    if logger is not None:
        logger.info('Initializing the model.')

    model = Model(**baseline_config['model'])
    model.to(device)

    n_optimization_steps = n_epochs * math.ceil(len(train) / train_batch_size)
    parameter_groups = [
        {
            'params': [
                param
                for name, param in model.named_parameters()
                if 'bias' in name
                or 'LayerNorm.bias' in name
                or 'LayerNorm.weight' in name
            ],
            'weight_decay': 0
        },
        {
            'params': [
                param
                for name, param in model.named_parameters()
                if 'bias' not in name
                and 'LayerNorm.bias' not in name
                and 'LayerNorm.weight' not in name
            ],
            'weight_decay': hyper_params['weight_decay']
        }
    ]
    optimizer = AdamW(parameter_groups, lr=hyper_params['lr'])

    if loss_type == 'xentropy-hard':
        loss = torch.nn.CrossEntropyLoss()
    elif loss_type == 'xentropy-soft':
        loss = SoftCrossEntropyLoss()
    elif loss_type == 'xentropy-full':
        loss = SoftCrossEntropyLoss()
    elif loss_type == 'dirichlet-multinomial':
        loss = DirichletMultinomialLoss()

    xentropy = SoftCrossEntropyLoss()

    scheduler = WarmupLinearSchedule(
        optimizer=optimizer,
        warmup_steps=int(
            hyper_params['warmup_proportion']
            * n_optimization_steps
        ),
        t_total=n_optimization_steps)

    # add data parallelism support
    model = torch.nn.DataParallel(model)

    # Step 8: Run training.

    n_train_batches_per_epoch = math.ceil(len(train) / train_batch_size)
    n_dev_batch_per_epoch = math.ceil(len(dev) / predict_batch_size)

    writer = tensorboardX.SummaryWriter(log_dir=tensorboard_dir)

    best_dev_calibrated_xentropy = math.inf
    for epoch in range(n_epochs):
        # set the model to training mode
        model.train()

        # run training for the epoch
        epoch_train_loss = 0
        epoch_train_xentropy = 0
        for i, (_, features, labels, label_scores) in tqdm.tqdm(
                enumerate(train_loader),
                total=n_gradient_accumulation * n_train_batches_per_epoch,
                **settings.TQDM_KWARGS
        ):
            # move the data onto the device
            features = {k: v.to(device) for k, v in features.items()}

            # create the targets
            if loss_type == 'xentropy-hard':
                targets = labels
            elif loss_type == 'xentropy-soft':
                targets = label_scores / torch.unsqueeze(
                    torch.sum(label_scores, dim=-1), dim=-1)
            elif loss_type == 'xentropy-full':
                targets = label_scores
            elif loss_type == 'dirichlet-multinomial':
                targets = label_scores
            # create the soft labels
            soft_labels = label_scores / torch.unsqueeze(
                torch.sum(label_scores, dim=-1), dim=-1)

            # move the targets and soft labels to the device
            targets = targets.to(device)
            soft_labels = soft_labels.to(device)

            # make predictions
            logits = model(**features)[0]

            batch_loss = loss(logits, targets)
            batch_xentropy = xentropy(logits, soft_labels)

            # update training statistics
            epoch_train_loss = (
                batch_loss.item() + i * epoch_train_loss
            ) / (i + 1)
            epoch_train_xentropy = (
                batch_xentropy.item() + i * epoch_train_xentropy
            ) / (i + 1)

            # update the network
            batch_loss.backward()

            if (i + 1) % n_gradient_accumulation == 0:
                optimizer.step()
                optimizer.zero_grad()

                scheduler.step()

            # write training statistics to tensorboard

            step = n_train_batches_per_epoch * epoch + (
                (i + 1) // n_gradient_accumulation)
            if step % 100 == 0 and (i + 1) % n_gradient_accumulation == 0:
                writer.add_scalar('train/loss', epoch_train_loss, step)
                writer.add_scalar('train/xentropy', epoch_train_xentropy, step)

        # run evaluation
        with torch.no_grad():
            # set the model to evaluation mode
            model.eval()

            # run validation for the epoch
            epoch_dev_loss = 0
            epoch_dev_soft_labels = []
            epoch_dev_logits = []
            for i, (_, features, labels, label_scores) in tqdm.tqdm(
                    enumerate(dev_loader),
                    total=n_dev_batch_per_epoch,
                    **settings.TQDM_KWARGS):
                # move the data onto the device
                features = {k: v.to(device) for k, v in features.items()}

                # create the targets
                if loss_type == 'xentropy-hard':
                    targets = labels
                elif loss_type == 'xentropy-soft':
                    targets = label_scores / torch.unsqueeze(
                        torch.sum(label_scores, dim=-1), dim=-1)
                elif loss_type == 'xentropy-full':
                    targets = label_scores
                elif loss_type == 'dirichlet-multinomial':
                    targets = label_scores

                # move the targets to the device
                targets = targets.to(device)

                # make predictions
                logits = model(**features)[0]

                batch_loss = loss(logits, targets)

                # update validation statistics
                epoch_dev_loss = (
                    batch_loss.item() + i * epoch_dev_loss
                ) / (i + 1)
                epoch_dev_soft_labels.extend(
                    (
                        label_scores
                        / torch.unsqueeze(torch.sum(label_scores, dim=-1), dim=-1)
                    ).cpu().numpy().tolist()
                )
                epoch_dev_logits.extend(logits.cpu().numpy().tolist())

            # compute validation statistics
            epoch_dev_soft_labels = np.array(epoch_dev_soft_labels)
            epoch_dev_logits = np.array(epoch_dev_logits)

            calibration_factor = utils.calibration_factor(
                logits=epoch_dev_logits,
                targets=epoch_dev_soft_labels)

            epoch_dev_xentropy = utils.xentropy(
                y_true=epoch_dev_soft_labels,
                y_pred=softmax(epoch_dev_logits, axis=-1))
            epoch_dev_calibrated_xentropy = utils.xentropy(
                y_true=epoch_dev_soft_labels,
                y_pred=softmax(epoch_dev_logits / calibration_factor, axis=-1))

            # write validation statistics to tensorboard
            writer.add_scalar('dev/loss', epoch_dev_loss, step)
            writer.add_scalar('dev/xentropy', epoch_dev_xentropy, step)
            writer.add_scalar(
                'dev/calibrated-xentropy', epoch_dev_calibrated_xentropy, step)

            if logger is not None:
                logger.info(
                    f'\n\n'
                    f'  epoch {epoch}:\n'
                    f'    train loss              : {epoch_train_loss:.4f}\n'
                    f'    train xentropy          : {epoch_train_xentropy:.4f}\n'
                    f'    dev loss                : {epoch_dev_loss:.4f}\n'
                    f'    dev xentropy            : {epoch_dev_xentropy:.4f}\n'
                    f'    dev calibrated xentropy : {epoch_dev_calibrated_xentropy:.4f}\n'
                    f'    calibration factor      : {calibration_factor:.4f}\n')

        # update checkpoints

        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'calibration_factor': calibration_factor
            },
            last_checkpoint_path)

        # update the current best model
        if epoch_dev_calibrated_xentropy < best_dev_calibrated_xentropy:
            shutil.copyfile(last_checkpoint_path, best_checkpoint_path)
            best_dev_calibrated_xentropy = epoch_dev_calibrated_xentropy

        # exit early if the training loss has diverged
        if math.isnan(epoch_train_loss):
            logger.info('Training loss has diverged. Exiting early.')

            return best_dev_calibrated_xentropy, True

    logger.info(
        f'Training complete. Best dev calibrated xentropy was'
        f' {best_dev_calibrated_xentropy:.4f}.')

    return best_dev_calibrated_xentropy, False
Пример #36
0
    def _epoch(self, data: LabelledCollection, posteriors, iterations, epoch,
               early_stop, train):
        """Run one QuaNet training or validation pass over sampled bags.

        Each iteration draws a bag of ``self.sample_size`` instances from
        ``data``, feeds its instances, posteriors and aggregative estimates to
        ``self.quanet`` and compares the predicted prevalence with the bag's
        true prevalence using MSE and ``mae_loss``.

        :param data: collection the bags are sampled from.
        :param posteriors: posterior probabilities aligned with ``data``; they
            are indexed with the same sampling index as ``data``.
        :param iterations: number of bags to draw. In validation mode it is
            only a budget used to derive a grid of prevalence points; the
            actual number of bags is recomputed from that grid.
        :param epoch: epoch number, only shown in the progress bar.
        :param early_stop: early-stopping tracker; ``patience`` and
            ``PATIENCE_LIMIT`` are only read for the progress bar.
        :param train: ``True`` to update the network (optimizer step),
            ``False`` to evaluate without tracking gradients.

        Side effect: the running mean MSE / MAE of the pass is written to
        ``self.status`` under ``'tr-loss'``/``'tr-mae'`` (train) or
        ``'va-loss'``/``'va-mae'`` (validation).
        """
        mse_loss = MSELoss()

        self.quanet.train(mode=train)
        losses = []
        mae_errors = []
        if not train:
            # Validation uses a fixed grid of prevalence points. The number of
            # bags is recomputed from the grid, so ``iterations`` changes here.
            prevpoints = F.get_nprevpoints_approximation(
                iterations, self.quanet.n_classes)
            iterations = F.num_prevalence_combinations(prevpoints,
                                                       self.quanet.n_classes)
            # Materialise the indexes *inside* the seeded context. If
            # ``artificial_sampling_index_generator`` is lazy (a generator
            # function -- TODO confirm), iterating it after this ``with``
            # block would draw the samples once ``temp_seed`` has already
            # released the RNG. The validation bags would then not be
            # reproducible.
            with qp.util.temp_seed(0):
                sampling_index_gen = list(
                    data.artificial_sampling_index_generator(
                        self.sample_size, prevpoints))
        else:
            # Training: prevalences drawn with F.uniform_simplex_sampling.
            sampling_index_gen = [
                data.sampling_index(self.sample_size, *prev) for prev in
                F.uniform_simplex_sampling(data.n_classes, iterations)
            ]
        pbar = tqdm(sampling_index_gen,
                    total=iterations) if train else sampling_index_gen
        # ``self.status`` keys updated by this pass.
        loss_key, mae_key = ('tr-loss', 'tr-mae') if train else ('va-loss', 'va-mae')

        for it, index in enumerate(pbar):
            sample_data = data.sampling_from_index(index)
            sample_posteriors = posteriors[index]
            quant_estims = self._get_aggregative_estims(sample_posteriors)
            ptrue = torch.as_tensor([sample_data.prevalence()],
                                    dtype=torch.float,
                                    device=self.device)
            if train:
                self.optim.zero_grad()
            # Only build the autograd graph when the parameters will be updated.
            with torch.set_grad_enabled(train):
                phat = self.quanet.forward(sample_data.instances,
                                           sample_posteriors, quant_estims)
                loss = mse_loss(phat, ptrue)
                mae = mae_loss(phat, ptrue)
            if train:
                loss.backward()
                self.optim.step()

            losses.append(loss.item())
            mae_errors.append(mae.item())

            # Running means over the bags processed so far in this pass.
            self.status[loss_key] = np.mean(losses)
            self.status[mae_key] = np.mean(mae_errors)

            if train:
                # NOTE(review): the description also reads the 'va-*' entries,
                # so they must already exist in ``self.status`` before the
                # first training pass -- confirm where they are initialised.
                pbar.set_description(
                    f'[QuaNet] '
                    f'epoch={epoch} [it={it}/{iterations}]\t'
                    f'tr-mseloss={self.status["tr-loss"]:.5f} tr-maeloss={self.status["tr-mae"]:.5f}\t'
                    f'val-mseloss={self.status["va-loss"]:.5f} val-maeloss={self.status["va-mae"]:.5f} '
                    f'patience={early_stop.patience}/{early_stop.PATIENCE_LIMIT}'
                )
    def perturb(self, X, y):
        """Build adversarial examples for ``X`` by iteratively updating a coupling ``pi``.

        For ``self.nb_iter`` iterations, the coupling is mapped to an image by
        ``self.coupling2adversarial``. That image is clamped and scored with
        ``self.predict``, and the gradient of ``self.loss_fn`` w.r.t. ``pi`` is
        passed to ``entr_support_func``. ``pi`` is then moved towards the
        returned ``optimal_pi`` with step size ``2 / (t + 2)``. If
        ``self.postprocess is True``, the final coupling is post-processed by
        ``dual_capacity_constrained_projection``.

        :param X: input batch; unpacked as ``(batch_size, c, h, w)``.
        :param y: targets given to ``self.loss_fn`` and compared with the argmax
            of the scores for the accuracy bookkeeping.
        :return: the adversarial examples from the final coupling. They are
            intentionally not clipped (see the note before the return).

        Side effects: appends to ``self.lst_loss`` / ``self.lst_acc``, and
        accumulates ``self.run_time`` (milliseconds from CUDA events),
        ``self.num_iter`` and ``self.func_calls``. Uses ``torch.cuda.Event``
        and ``torch.cuda.synchronize``, so a CUDA device is required.
        """
        batch_size, c, h, w = X.size()

        # NOTE(review): presumably prepares ``self.cost`` (read below) -- confirm.
        self.initialize_cost(X, inf=self.inf)
        # Fresh leaf tensor for the coupling so that loss.backward() fills pi.grad.
        pi = self.initialize_coupling(X).clone().detach().requires_grad_(True)
        # Per-example total mass of X, kept as (batch, 1, 1, 1) for broadcasting.
        normalization = X.sum(dim=(1, 2, 3), keepdim=True)

        for t in range(self.nb_iter):
            # Score the (clamped) image induced by the current coupling.
            adv_example = self.coupling2adversarial(pi, X)
            scores = self.predict(
                adv_example.clamp(min=self.clip_min, max=self.clip_max))

            loss = self.loss_fn(scores, y)
            loss.backward()

            with torch.no_grad():
                # Bookkeeping: loss value and number of correctly classified examples.
                self.lst_loss.append(loss.item())
                self.lst_acc.append((scores.max(dim=1)[1] == y).sum().item())
                """Add a small constant to enhance numerical stability"""
                # Rescale each example's gradient by tensor_norm(p='inf');
                # the 1e-35 term avoids division by zero.
                pi.grad /= (
                    tensor_norm(pi.grad, p='inf').view(batch_size, 1, 1, 1) +
                    1e-35)
                # The gradient must contain no NaN (x == x is False for NaN) and no +/-inf.
                assert (pi.grad == pi.grad).all() and (
                    pi.grad != float('inf')).all() and (pi.grad !=
                                                        float('-inf')).all()

                # Time the entr_support_func call with CUDA events
                # (elapsed_time is reported in milliseconds).
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)

                start.record()

                optimal_pi, num_iter = entr_support_func(
                    pi.grad,
                    X,
                    cost=self.cost,
                    inf=self.inf,
                    eps=self.eps * normalization.squeeze(),
                    gamma=self.entrp_gamma,
                    dual_max_iter=self.dual_max_iter,
                    grad_tol=self.grad_tol,
                    int_tol=self.int_tol)
                end.record()

                # Wait for the recorded events before reading elapsed_time.
                torch.cuda.synchronize()

                self.run_time += start.elapsed_time(end)
                self.num_iter += num_iter
                self.func_calls += 1

                # Progress report every 10 outer iterations.
                if self.verbose and (t + 1) % 10 == 0:
                    print(
                        "num of iters : {:4d}, ".format(t + 1),
                        "loss : {:12.6f}, ".format(loss.item()),
                        "acc : {:5.2f}%, ".format(
                            (scores.max(dim=1)[1] == y).sum().item() /
                            batch_size * 100),
                        "dual iter : {:2d}, ".format(num_iter),
                        "per iter time : {:7.3f}ms".format(
                            start.elapsed_time(end) / num_iter))

                # Convex-combination step towards optimal_pi with step 2 / (t + 2).
                step = 2. / (t + 2)
                pi += step * (optimal_pi - pi)
                # Clear the accumulated gradient before the next backward pass.
                pi.grad.zero_()

                # Silent feasibility checks on the mass-normalised coupling.
                self.check_nonnegativity(pi / normalization,
                                         tol=1e-6,
                                         verbose=False)
                self.check_marginal_constraint(pi / normalization,
                                               X / normalization,
                                               tol=1e-5,
                                               verbose=False)
                self.check_transport_cost(pi / normalization,
                                          tol=1e-3,
                                          verbose=False)

        with torch.no_grad():
            # Final checks, reported only when self.verbose is set.
            adv_example = self.coupling2adversarial(pi, X)
            check_hypercube(adv_example, verbose=self.verbose)
            self.check_nonnegativity(pi / normalization,
                                     tol=1e-4,
                                     verbose=self.verbose)
            self.check_marginal_constraint(pi / normalization,
                                           X / normalization,
                                           tol=1e-4,
                                           verbose=self.verbose)
            self.check_transport_cost(pi / normalization,
                                      tol=self.eps * 1e-3,
                                      verbose=self.verbose)

            if self.postprocess is True:
                if self.verbose:
                    print("==========> post-processing projection........")

                # Replace pi by its projection, then rebuild the image and re-run
                # the checks (the hypercube check uses a looser tolerance here).
                pi = dual_capacity_constrained_projection(
                    pi,
                    X,
                    self.cost,
                    eps=self.eps * normalization.squeeze(),
                    transpose_idx=self.forward_idx,
                    detranspose_idx=self.backward_idx,
                    coupling2adversarial=self.coupling2adversarial)

                adv_example = self.coupling2adversarial(pi, X)
                check_hypercube(adv_example,
                                tol=self.eps * 1e-1,
                                verbose=self.verbose)
                self.check_nonnegativity(pi / normalization,
                                         tol=1e-6,
                                         verbose=self.verbose)
                self.check_marginal_constraint(pi / normalization,
                                               X / normalization,
                                               tol=1e-6,
                                               verbose=self.verbose)
                self.check_transport_cost(pi / normalization,
                                          tol=self.eps * 1e-3,
                                          verbose=self.verbose)
        """Do not clip the adversarial examples to preserve pixel mass"""
        return adv_example
Пример #38
0
    def translate_batch(self, src_seq, src_pos):
        ''' Translate one batch of source sequences with beam search.

        The batch is encoded once. Each instance's rows are repeated
        ``self.opt.beam_size`` times, and decoding proceeds one token per step
        for at most ``self.model_opt.max_token_seq_len`` steps. Instances whose
        beam reports completion are removed from the working tensors, so the
        decoder only runs on sentences that are still active.

        :param src_seq: source token ids; after encoding, the batch is treated
            as ``n_inst`` rows of length ``len_s`` (from ``src_enc.size()``).
        :param src_pos: source positions, passed to the encoder only.
        :return: ``(batch_hyp, batch_scores)``. For every instance these are
            up to ``self.opt.n_best`` hypotheses (``Beam.get_hypothesis``) and
            their scores (``Beam.sort_scores``).
        '''

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Map each instance index to its position in the active tensors. '''
            return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}

        def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
            ''' Keep only the rows of ``beamed_tensor`` that belong to active instances.

            ``beamed_tensor`` holds ``n_bm`` consecutive rows per previously
            active instance. ``curr_active_inst_idx`` gives the positions, in
            that previous layout, of the instances to keep.
            '''
            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            # One row per instance, so index_select keeps or drops all beams of
            # an instance together.
            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list, n_bm):
            ''' Shrink the source tensors and the position map to the active instances. '''
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
            ''' Decode one step, update the beams, and return the still-active instance indices. '''

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                # Current partial sequences of all unfinished beams, one row per beam.
                dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
                # Positions 1..len_dec_seq, repeated for every active beam row.
                dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
                dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
                return dec_partial_pos

            def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
                dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
                dec_output = dec_output[:, -1, :]  # Pick the last step: (bh * bm) * d_h
                word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list.append(inst_idx)

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
            word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            ''' Return the ``n_best`` hypotheses and scores of every beam, in beam order. '''
            all_hyp, all_scores = [], []
            for beam in inst_dec_beams:
                scores, tail_idxs = beam.sort_scores()
                all_scores.append(scores[:n_best])
                all_hyp.append([beam.get_hypothesis(i) for i in tail_idxs[:n_best]])
            return all_hyp, all_scores

        with torch.no_grad():
            #-- Encode
            src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
            src_enc, *_ = self.model.encoder(src_seq, src_pos)

            #-- Repeat data for beam search: each instance gets n_bm consecutive rows
            n_bm = self.opt.beam_size
            n_inst, len_s, d_h = src_enc.size()
            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)

            #-- Prepare beams
            inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list, n_bm)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)

        return batch_hyp, batch_scores
def main():
    """Train (or load) VGG16 on FashionMNIST and print its test accuracy.

    With ``PRE_TRAINED`` set, the weights are loaded from ``VGG16_PATH``.
    Otherwise the network is trained for 2 epochs (SGD, lr 0.001, momentum
    0.9, batch size 4) and its weights are saved to ``VGG16_PATH``. Every
    image batch is tiled with ``repeat(1, 3, 2, 2)`` (3x channels, 2x height
    and width) before it is fed to VGG16.
    """
    import os  # local import: used to create the weight directory before saving

    PRE_TRAINED = 0
    VGG16_PATH = './weight/FashionMNIST_vgg16.pth'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # NOTE: torchvision's default VGG16 head has 1000 outputs; only the first
    # 10 logits correspond to FashionMNIST labels. Kept unchanged so weights
    # previously saved at VGG16_PATH still load.
    vgg16 = models.vgg16()
    vgg16.to(device)

    to_tensor = transforms.Compose([transforms.ToTensor()])

    trainset_fashion = torchvision.datasets.FashionMNIST(
        root='./data/pytorch/FashionMNIST',
        train=True,
        download=True,
        transform=to_tensor)

    testset_fashion = torchvision.datasets.FashionMNIST(
        root='./data/pytorch/FashionMNIST',
        train=False,
        download=True,
        transform=to_tensor)

    trainloader_fashion = torch.utils.data.DataLoader(trainset_fashion, batch_size=4,
                                                      shuffle=True, num_workers=2)
    testloader_fashion = torch.utils.data.DataLoader(testset_fashion, batch_size=4,
                                                     shuffle=False, num_workers=2)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(vgg16.parameters(), lr=0.001, momentum=0.9)
    if PRE_TRAINED:
        # map_location lets weights saved from a GPU run load on a CPU-only host.
        vgg16.load_state_dict(torch.load(VGG16_PATH, map_location=device))
    else:
        start_time = time.time()
        print("Start Training >>>")
        vgg16.train()
        for epoch in range(2):
            running_loss = 0.0
            for i, data in enumerate(trainloader_fashion, 0):
                inputs, labels = data[0].to(device), data[1].to(device)
                inputs = inputs.repeat(1, 3, 2, 2)
                optimizer.zero_grad()
                outputs = vgg16(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                # Report the mean loss of the last 2000 batches.
                if i % 2000 == 1999:
                    print(f'[Epoch: {epoch + 1}, Batch: {i + 1}] loss: {running_loss / 2000}')
                    running_loss = 0.0
        train_time = (time.time() - start_time) / 60
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(VGG16_PATH), exist_ok=True)
        torch.save(vgg16.state_dict(), VGG16_PATH)
        print('>>> Finished Training')
        print(f'Training time: {train_time} mins.')

    start_test = time.time()
    print("\nStart Testing >>>")
    # Evaluation mode: disables VGG16's dropout so test predictions are deterministic.
    vgg16.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for i, data in enumerate(testloader_fashion, 0):
            images, labels = data[0].to(device), data[1].to(device)
            images = images.repeat(1, 3, 2, 2)
            outputs = vgg16(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if i % 2000 == 1999:
                print(f'Testing Batch: {i + 1}')
    test_time = (time.time() - start_test) / 60
    print('>>> Finished Testing')
    print(f'Testing time: {test_time} mins.')
    print(f'Accuracy: {100 * correct / total}')
def Train(model_name, task, resume, check_name, saved_model, batch_size, lr, momentum, weight_decay, lr_factor, end_epoch):
    """Fine-tune a classifier for one attribute ``task`` and keep the best model.

    Builds a pretrained network with ``build_model`` and trains it with Adam
    and cross-entropy on ``./data/train_valid/<task>/train``. After every
    epoch it is evaluated on ``./data/train_valid/<task>/val``. Whenever the
    validation accuracy improves, two files are written:
    - a metadata checkpoint (loss, epoch, mAP, accuracy, history) to ``check_name``;
    - a CPU copy of the whole ``DataParallel`` model, pickled to ``saved_model``.

    :param model_name: architecture name forwarded to ``build_model``.
    :param task: key of ``task_list``; also the data sub-directory name.
    :param resume: non-zero to resume from ``check_name`` / ``saved_model``.
    :param check_name: path of the metadata checkpoint.
    :param saved_model: path of the pickled model.
    :param batch_size: batch size of both data loaders.
    :param lr: Adam learning rate.
    :param momentum: unused (only referenced by the commented-out SGD optimizer).
    :param weight_decay: unused (only referenced by the commented-out SGD optimizer).
    :param lr_factor: unused (only referenced by the commented-out StepLR scheduler).
    :param end_epoch: exclusive upper bound of the 0-based epoch index.
    :return: the trained network wrapped in ``torch.nn.DataParallel``.
    :raises KeyError: if ``task`` is not a key of ``task_list``.
    """

    # Task -> number of classes
    task_list = {
        'collar_design_labels': 5,
        'skirt_length_labels': 6,
        'lapel_design_labels': 5,
        'neckline_design_labels': 10,
        'coat_length_labels': 8,
        'neck_design_labels': 5,
        'pant_length_labels': 6,
        'sleeve_length_labels': 9
    }
    num_classes = task_list[task]

    # Model
    print('==> Building model..')
    net = build_model(model_name=model_name, num_classes=num_classes, pretrained=True)
    if resume != 0:
        print('==> Resuming from checkpoint..')
        checkpoint_ = torch.load(check_name)
        # ``saved_model`` holds the pickled DataParallel wrapper written below,
        # so the bare network's weights live under ``.module``.
        net0 = torch.load(saved_model)
        net.load_state_dict(net0.module.state_dict())

        best_map = checkpoint_['map']
        best_acc = checkpoint_['acc']
        # The checkpoint stores the index of an epoch that already completed,
        # and its stats are already in ``history``. Continue with the next
        # epoch instead of repeating it and appending its stats twice.
        start_epoch = checkpoint_['epoch'] + 1
        history = checkpoint_['history']
    else:
        best_acc = 0.
        best_map = 0.  # only relevant for the alternative mAP-based criterion below
        start_epoch = 0
        history = {'train_loss': [], 'test_loss': [], 'train_map': [], 'test_map': [],
                   'train_acc': [], 'test_acc': []}

    # Data
    data_root = './data/train_valid/'
    traindir = os.path.join(data_root + task, 'train')
    testdir = os.path.join(data_root + task, 'val')

    train_transform = transforms.Compose([
        transforms.Resize(512),
        transforms.RandomRotation(15.0),
        transforms.CenterCrop(500),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.15),
        transforms.ColorJitter(contrast=0.15),
        transforms.ColorJitter(saturation=0.15),
        transforms.RandomGrayscale(0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))
    ])

    test_transform = transforms.Compose([
        transforms.Resize(512),
        transforms.CenterCrop(500),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.5, 0.5, 0.5))
    ])

    trainset = datasets.ImageFolder(traindir, transform=train_transform)
    testset = datasets.ImageFolder(testdir, transform=test_transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=30)
    # Evaluation does not need shuffling. A fixed order keeps per-batch
    # metrics (calculate_ap is called per batch) comparable across epochs.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=30)

    # use gpu
    net.cuda()
    print('use %d GPU' % torch.cuda.device_count())
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

    # loss and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    # optimizer = torch.optim.SGD(net.parameters(), lr=lr,
    #                             momentum= momentum,
    #                             weight_decay= weight_decay)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # scheduler
    # scheduler = StepLR(optimizer, step_size=60, gamma=lr_factor)

    # training
    logging.info('Start Training for %s' % task)
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()
        # scheduler.step()

        # train
        net.train()
        train_loss = 0
        train_AP = 0.
        train_AP_cnt = 0
        train_correct = 0
        print("Training...")
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            # Plain tensors suffice: ``Variable`` has been a no-op wrapper since PyTorch 0.4.
            inputs = inputs.cuda()
            targets = targets.cuda()

            optimizer.zero_grad()
            outputs = net(inputs)
            preds = outputs.detach().max(1, keepdim=True)[1]
            loss = criterion(outputs, targets)
            ap, cnt = calculate_ap(labels=targets, outputs=outputs)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += preds.eq(targets.view_as(preds)).sum().item()
            train_AP += ap
            train_AP_cnt += cnt

            if batch_idx % 50 == 0:
                print("epoch {0} / batch index {1} : loss {2}".format(epoch + 1, batch_idx, train_loss / (batch_idx + 1)))

        train_loss_epoch = train_loss / len(trainloader)
        train_acc_epoch = train_correct / len(trainloader.dataset)
        train_map_epoch = train_AP / train_AP_cnt
        history['train_loss'].append(train_loss_epoch)
        history['train_acc'].append(train_acc_epoch)
        history['train_map'].append(train_map_epoch)

        # test
        net.eval()
        test_loss = 0
        test_AP = 0.
        test_AP_cnt = 0
        test_correct = 0
        print("Testing...")
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs = inputs.cuda()
                targets = targets.cuda()

                outputs = net(inputs)
                preds = outputs.max(1, keepdim=True)[1]
                loss = criterion(outputs, targets)
                ap, cnt = calculate_ap(labels=targets, outputs=outputs)

                test_loss += loss.item()
                test_correct += preds.eq(targets.view_as(preds)).sum().item()
                test_AP += ap
                test_AP_cnt += cnt

        test_loss_epoch = test_loss / len(testloader)
        test_acc_epoch = test_correct / len(testloader.dataset)
        test_map_epoch = test_AP / test_AP_cnt
        history['test_loss'].append(test_loss_epoch)
        history['test_acc'].append(test_acc_epoch)
        history['test_map'].append(test_map_epoch)

        print("epoch[{0}/{1}]: test_acc: {2}.".format(epoch + 1, end_epoch, test_acc_epoch))

        time_cost = time.time() - ts
        logging.info('epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f || train_map: %.3f | test_map: %.3f || train_acc: %.3f | test_acc: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_map_epoch, test_map_epoch, 100 * train_acc_epoch, 100 * test_acc_epoch, time_cost))

        # save checkpoint model
        if test_acc_epoch > best_acc:  # alternative criterion: test_map_epoch > best_map
            print('Saving..')
            state = {
                'loss': test_loss_epoch,
                'epoch': epoch,
                'map': test_map_epoch,
                'acc': test_acc_epoch,
                'history': history
            }
            # Create missing (possibly nested) output directories. A path with
            # no directory part has an empty dirname, on which os.mkdir would raise.
            for out_path in (check_name, saved_model):
                out_dir = os.path.dirname(out_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
            torch.save(state, check_name)
            # Pickle a CPU copy so the live (GPU) model is left untouched.
            net_ = copy.deepcopy(net)
            net_.cpu()
            torch.save(net_, saved_model)
            best_acc = test_acc_epoch  # best_map = test_map_epoch

    return net
Пример #41
0
def run_test():
    """Evaluate the global ``model`` ``test_times`` times and report trajectory metrics.

    For every pass over ``data_loader`` the function accumulates, over the agents
    selected by ``tgt_three_mask`` (presumably agents with a full-length target —
    TODO confirm), the minimum and average ADE/FDE across the ``num_candidates``
    generated trajectories: the "2" variants use the first
    ``int(decoding_steps * 2 / 3)`` steps, the "3" variants use all steps. It also
    accumulates min/avg MSD and the map-based DAO/DAC metrics. DAO/DAC are only
    computed for the first scene of each batch, and that scene's image is written
    via ``write_img_output`` to ``'test/img'``.

    Each metric is averaged per pass (weighted by agent count). The per-pass
    values are then summarized as ``[mean, std]``, printed, and pickled to
    ``out_dir + '/metric.pkl'``.

    Relies on module-level names: ``model``, ``test_times``, ``map_version``,
    ``device``, ``data_loader``, ``decoding_steps``, ``num_candidates``,
    ``map_file``, ``out_dir``, ``dao``, ``dac``, ``write_img_output``, ``pkl``.
    """
    print('Starting model test.....')
    model.eval()  # Set model to evaluate mode.

    list_loss = []
    list_qloss = []
    list_ploss = []
    list_minade2, list_avgade2 = [], []
    list_minfde2, list_avgfde2 = [], []
    list_minade3, list_avgade3 = [], []
    list_minfde3, list_avgfde3 = [], []
    list_minmsd, list_avgmsd = [], []

    list_dao = []
    list_dac = []

    for test_time_ in range(test_times):

        epoch_loss = 0.0
        epoch_qloss = 0.0
        epoch_ploss = 0.0
        epoch_minade2, epoch_avgade2 = 0.0, 0.0
        epoch_minfde2, epoch_avgfde2 = 0.0, 0.0
        epoch_minade3, epoch_avgade3 = 0.0, 0.0
        epoch_minfde3, epoch_avgfde3 = 0.0, 0.0
        epoch_minmsd, epoch_avgmsd = 0.0, 0.0
        epoch_agents, epoch_agents2, epoch_agents3 = 0.0, 0.0, 0.0

        epoch_dao = 0.0
        epoch_dac = 0.0
        dao_agents = 0.0
        dac_agents = 0.0

        H = W = 64
        with torch.no_grad():
            if map_version == '2.0':
                # NOTE(review): these normalized coordinate/distance channels are
                # only consumed by the disabled block inside the batch loop, so
                # they are currently unused.
                coordinate_2d = np.indices((H, W))
                coordinate = np.ravel_multi_index(coordinate_2d, dims=(H, W))
                coordinate = torch.FloatTensor(coordinate)
                coordinate = coordinate.reshape((1, 1, H, W))

                coordinate_std, coordinate_mean = torch.std_mean(coordinate)
                coordinate = (coordinate - coordinate_mean) / coordinate_std

                distance_2d = coordinate_2d - np.array([(H - 1) / 2, (H - 1) / 2]).reshape((2, 1, 1))
                distance = np.sqrt((distance_2d ** 2).sum(axis=0))
                distance = torch.FloatTensor(distance)
                distance = distance.reshape((1, 1, H, W))

                distance_std, distance_mean = torch.std_mean(distance)
                distance = (distance - distance_mean) / distance_std

                coordinate = coordinate.to(device)
                distance = distance.to(device)

            for b, batch in enumerate(data_loader):

                scene_images, log_prior, \
                agent_masks, \
                num_src_trajs, src_trajs, src_lens, src_len_idx, \
                num_tgt_trajs, tgt_trajs, tgt_lens, tgt_len_idx, \
                tgt_two_mask, tgt_three_mask, \
                decode_start_vel, decode_start_pos, scene_id, batch_size = batch

                # Detect dynamic batch size

                num_three_agents = torch.sum(tgt_three_mask)
                """
                if map_version == '2.0':
                    coordinate_batch = coordinate.repeat(batch_size, 1, 1, 1)
                    distance_batch = distance.repeat(batch_size, 1, 1, 1)
                    scene_images = torch.cat((scene_images.to(device), coordinate_batch, distance_batch), dim=1)
                """
                src_trajs = src_trajs.to(device)
                src_lens = src_lens.to(device)

                # Keep only the targets selected by tgt_three_mask.
                tgt_trajs = tgt_trajs.to(device)[tgt_three_mask]
                tgt_lens = tgt_lens.to(device)[tgt_three_mask]

                num_tgt_trajs = num_tgt_trajs.to(device)
                episode_idx = torch.arange(batch_size, device=device).repeat_interleave(num_tgt_trajs)[tgt_three_mask]

                # Map the selected targets back onto the full agent list.
                agent_masks = agent_masks.to(device)
                agent_tgt_three_mask = torch.zeros_like(agent_masks)
                agent_masks_idx = torch.arange(len(agent_masks), device=device)[agent_masks][tgt_three_mask]
                agent_tgt_three_mask[agent_masks_idx] = True

                decode_start_vel = decode_start_vel.to(device)[agent_tgt_three_mask]
                decode_start_pos = decode_start_pos.to(device)[agent_tgt_three_mask]

                log_prior = log_prior.to(device)

                gen_trajs = model(src_trajs, src_lens, agent_tgt_three_mask, decode_start_vel, decode_start_pos, num_src_trajs, scene_images)

                # (agents, candidates, steps, xy)
                gen_trajs = gen_trajs.reshape(num_three_agents, num_candidates, decoding_steps, 2)

                # Per-step Euclidean error of every candidate against the target.
                rs_error3 = ((gen_trajs - tgt_trajs.unsqueeze(1)) ** 2).sum(dim=-1).sqrt_()
                rs_error2 = rs_error3[..., :int(decoding_steps * 2 / 3)]

                diff = gen_trajs - tgt_trajs.unsqueeze(1)
                msd_error = (diff[:, :, :, 0] ** 2 + diff[:, :, :, 1] ** 2)

                num_agents = gen_trajs.size(0)
                num_agents2 = rs_error2.size(0)
                num_agents3 = rs_error3.size(0)

                ade2 = rs_error2.mean(-1)
                fde2 = rs_error2[..., -1]

                minade2, _ = ade2.min(dim=-1)
                avgade2 = ade2.mean(dim=-1)
                minfde2, _ = fde2.min(dim=-1)
                avgfde2 = fde2.mean(dim=-1)

                batch_minade2 = minade2.mean()
                batch_minfde2 = minfde2.mean()
                batch_avgade2 = avgade2.mean()
                batch_avgfde2 = avgfde2.mean()

                ade3 = rs_error3.mean(-1)
                fde3 = rs_error3[..., -1]

                msd = msd_error.mean(-1)
                minmsd, _ = msd.min(dim=-1)
                avgmsd = msd.mean(dim=-1)
                batch_minmsd = minmsd.mean()
                batch_avgmsd = avgmsd.mean()

                minade3, _ = ade3.min(dim=-1)
                avgade3 = ade3.mean(dim=-1)
                minfde3, _ = fde3.min(dim=-1)
                avgfde3 = fde3.mean(dim=-1)

                batch_minade3 = minade3.mean()
                batch_minfde3 = minfde3.mean()
                batch_avgade3 = avgade3.mean()
                batch_avgfde3 = avgfde3.mean()

                batch_loss = batch_minade3
                # Weight by agent count like the other metrics: the per-pass value
                # is later divided by the total number of agents (epoch_agents).
                epoch_loss += batch_loss.item() * num_agents3
                batch_qloss = torch.zeros(1)
                batch_ploss = torch.zeros(1)

                print("Working on test {:d}/{:d}, batch {:d}/{:d}... ".format(test_time_ + 1, test_times, b + 1,
                                                                              len(data_loader)), end='\r')  # +

                epoch_ploss += batch_ploss.item() * batch_size
                epoch_qloss += batch_qloss.item() * batch_size
                epoch_minade2 += batch_minade2.item() * num_agents2
                epoch_avgade2 += batch_avgade2.item() * num_agents2
                epoch_minfde2 += batch_minfde2.item() * num_agents2
                epoch_avgfde2 += batch_avgfde2.item() * num_agents2
                epoch_minade3 += batch_minade3.item() * num_agents3
                epoch_avgade3 += batch_avgade3.item() * num_agents3
                epoch_minfde3 += batch_minfde3.item() * num_agents3
                epoch_avgfde3 += batch_avgfde3.item() * num_agents3

                epoch_minmsd += batch_minmsd.item() * num_agents3
                epoch_avgmsd += batch_avgmsd.item() * num_agents3

                epoch_agents += num_agents
                epoch_agents2 += num_agents2
                epoch_agents3 += num_agents3

                map_files = map_file(scene_id)
                output_files = [out_dir + '/' + x[2] + '_' + x[3] + '.jpg' for x in scene_id]

                cum_num_tgt_trajs = [0] + torch.cumsum(num_tgt_trajs, dim=0).tolist()
                cum_num_src_trajs = [0] + torch.cumsum(num_src_trajs, dim=0).tolist()

                src_trajs = src_trajs.cpu().numpy()
                src_lens = src_lens.cpu().numpy()

                tgt_trajs = tgt_trajs.cpu().numpy()
                tgt_lens = tgt_lens.cpu().numpy()

                # Insertion points that restore the rows dropped by tgt_three_mask,
                # so that the arrays line up with cum_num_tgt_trajs again.
                zero_ind = np.nonzero(tgt_three_mask.numpy() == 0)[0]
                zero_ind -= np.arange(len(zero_ind))

                tgt_three_mask = tgt_three_mask.numpy()
                agent_tgt_three_mask = agent_tgt_three_mask.cpu().numpy()

                gen_trajs = gen_trajs.cpu().numpy()

                gen_trajs = np.insert(gen_trajs, zero_ind, 0, axis=0)

                tgt_trajs = np.insert(tgt_trajs, zero_ind, 0, axis=0)
                tgt_lens = np.insert(tgt_lens, zero_ind, 0, axis=0)

                # NOTE(review): only the first scene of each batch is scored for
                # DAO/DAC and rendered; confirm this is intended when batch_size > 1.
                for i in range(1):
                    candidate_i = gen_trajs[cum_num_tgt_trajs[i]:cum_num_tgt_trajs[i + 1]]
                    tgt_traj_i = tgt_trajs[cum_num_tgt_trajs[i]:cum_num_tgt_trajs[i + 1]]
                    tgt_lens_i = tgt_lens[cum_num_tgt_trajs[i]:cum_num_tgt_trajs[i + 1]]

                    src_traj_i = src_trajs[cum_num_src_trajs[i]:cum_num_src_trajs[i + 1]]
                    src_lens_i = src_lens[cum_num_src_trajs[i]:cum_num_src_trajs[i + 1]]
                    map_file_i = map_files[i]
                    output_file_i = output_files[i]

                    candidate_i = candidate_i[tgt_three_mask[cum_num_tgt_trajs[i]:cum_num_tgt_trajs[i + 1]]]
                    tgt_traj_i = tgt_traj_i[tgt_three_mask[cum_num_tgt_trajs[i]:cum_num_tgt_trajs[i + 1]]]
                    tgt_lens_i = tgt_lens_i[tgt_three_mask[cum_num_tgt_trajs[i]:cum_num_tgt_trajs[i + 1]]]

                    src_traj_i = src_traj_i[agent_tgt_three_mask[cum_num_src_trajs[i]:cum_num_src_trajs[i + 1]]]
                    src_lens_i = src_lens_i[agent_tgt_three_mask[cum_num_src_trajs[i]:cum_num_src_trajs[i + 1]]]

                    dao_i, dao_mask_i = dao(candidate_i, map_file_i)
                    dac_i, dac_mask_i = dac(candidate_i, map_file_i)

                    epoch_dao += dao_i.sum()
                    dao_agents += dao_mask_i.sum()

                    epoch_dac += dac_i.sum()
                    dac_agents += dac_mask_i.sum()

                    write_img_output(candidate_i, src_traj_i, src_lens_i, tgt_traj_i, tgt_lens_i, map_file_i,
                                     'test/img')
            # Terminate the '\r'-overwritten progress line.
            print()

        list_loss.append(epoch_loss / epoch_agents)

        # 2-Loss
        list_minade2.append(epoch_minade2 / epoch_agents2)
        list_avgade2.append(epoch_avgade2 / epoch_agents2)
        list_minfde2.append(epoch_minfde2 / epoch_agents2)
        list_avgfde2.append(epoch_avgfde2 / epoch_agents2)

        # 3-Loss
        list_minade3.append(epoch_minade3 / epoch_agents3)
        list_avgade3.append(epoch_avgade3 / epoch_agents3)
        list_minfde3.append(epoch_minfde3 / epoch_agents3)
        list_avgfde3.append(epoch_avgfde3 / epoch_agents3)

        list_minmsd.append(epoch_minmsd / epoch_agents3)
        list_avgmsd.append(epoch_avgmsd / epoch_agents3)

        list_dao.append(epoch_dao / dao_agents)
        list_dac.append(epoch_dac / dac_agents)

    # Summarize every metric as [mean, std] across the test passes.
    test_ploss = [0.0, 0.0]
    test_qloss = [0.0, 0.0]
    test_loss = [np.mean(list_loss), np.std(list_loss)]

    test_minade2 = [np.mean(list_minade2), np.std(list_minade2)]
    test_avgade2 = [np.mean(list_avgade2), np.std(list_avgade2)]
    test_minfde2 = [np.mean(list_minfde2), np.std(list_minfde2)]
    test_avgfde2 = [np.mean(list_avgfde2), np.std(list_avgfde2)]

    test_minade3 = [np.mean(list_minade3), np.std(list_minade3)]
    test_avgade3 = [np.mean(list_avgade3), np.std(list_avgade3)]
    test_minfde3 = [np.mean(list_minfde3), np.std(list_minfde3)]
    test_avgfde3 = [np.mean(list_avgfde3), np.std(list_avgfde3)]

    test_minmsd = [np.mean(list_minmsd), np.std(list_minmsd)]
    test_avgmsd = [np.mean(list_avgmsd), np.std(list_avgmsd)]

    test_dao = [np.mean(list_dao), np.std(list_dao)]
    test_dac = [np.mean(list_dac), np.std(list_dac)]

    test_ades = (test_minade2, test_avgade2, test_minade3, test_avgade3)
    test_fdes = (test_minfde2, test_avgfde2, test_minfde3, test_avgfde3)

    print("--Final Performane Report--")
    print("minADE3: {:.5f}±{:.5f}, minFDE3: {:.5f}±{:.5f}".format(test_minade3[0], test_minade3[1], test_minfde3[0],
                                                                  test_minfde3[1]))
    print("avgADE3: {:.5f}±{:.5f}, avgFDE3: {:.5f}±{:.5f}".format(test_avgade3[0], test_avgade3[1], test_avgfde3[0],
                                                                  test_avgfde3[1]))
    # DAO is reported scaled by 1e4.
    print("DAO: {:.5f}±{:.5f}, DAC: {:.5f}±{:.5f}".format(test_dao[0] * 10000.0, test_dao[1] * 10000.0, test_dac[0],
                                                          test_dac[1]))
    with open(out_dir + '/metric.pkl', 'wb') as f:
        pkl.dump({"ADEs": test_ades,
                  "FDEs": test_fdes,
                  "Qloss": test_qloss,
                  "Ploss": test_ploss,
                  "DAO": test_dao,
                  "DAC": test_dac}, f)
Пример #42
0
 def _sliding_window_processor(engine, batch):
     """Run sliding-window inference on a single evaluation batch.

     ``engine`` is accepted but not used (the signature presumably matches an
     evaluation-engine process function — confirm with the caller). ``batch[0]``
     holds the images and ``batch[1]`` the labels; both are moved to ``device``.
     The enclosing-scope ``net`` is put in eval mode and applied through
     ``sliding_window_inference`` with ``roi_size`` and ``sw_batch_size``.

     Returns ``(inference_output, labels)``, computed without gradient tracking.
     """
     net.eval()
     with torch.no_grad():
         images = batch[0].to(device)
         labels = batch[1].to(device)
         prediction = sliding_window_inference(images, roi_size, sw_batch_size, net)
     return prediction, labels
Пример #43
0
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and log accuracy plus weight sparsity.

    The model is switched to eval mode (and left there). Loss and top-1/top-5
    precision (via ``accuracy``) are tracked in ``AverageMeter`` instances, and
    progress is logged every 100 batches. After the pass, the element-wise and
    structured sparsity of the model's multi-dimensional ``weight`` parameters
    are logged as well.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: network to evaluate.
        criterion: loss function called as ``criterion(output, target)``.

    Returns:
        ``(top1.avg, top5.avg)``, the averaged precision@1 and precision@5.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            if args.gpu is not None:
                input = input.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 100 == 0:
                logger.info('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1, top5=top5))

        elt_sparsity, input_sparsity, output_sparsity = _weight_sparsity(model)

        logger.info(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Sparsity elt: {elt_sparsity:.3f} str: {input_sparsity:.3f} str_avg: {output_sparsity:.3f}'
              .format(top1=top1, top5=top5, elt_sparsity=elt_sparsity, input_sparsity=input_sparsity, output_sparsity=output_sparsity))

    return top1.avg, top5.avg


def _weight_sparsity(model):
    """Return ``(elt_sparsity, input_sparsity, output_sparsity)`` for ``model``.

    Only parameters whose name contains ``'weight'`` and that have more than one
    dimension are considered.

    * ``elt_sparsity``: fraction of zero entries across all such tensors.
    * ``input_sparsity``: over 2-D and 4-D weights only,
      ``1 - sum(nz_axis1 * nz_axis0) / sum(len_axis1 * len_axis0)``. Here a slice
      along axis 1 (or axis 0) counts as nonzero if any of its weights is nonzero.
      Under PyTorch's ``(out, in, ...)`` weight layout these are the input and
      output channels.
    * ``output_sparsity``: the same per-layer ratio, averaged over those layers.

    A ratio whose denominator is zero (e.g. the model has no 2-D/4-D weights) is
    reported as 0.0 instead of raising ``ZeroDivisionError``.
    """
    nonzero = total = 0
    filter_count = filter_total = 0
    total_sparsity = 0.0
    total_layer = 0
    for name, p in model.named_parameters():
        if 'weight' not in name or p.dim() <= 1:
            continue
        tensor = np.abs(p.data.cpu().numpy())
        nonzero += np.count_nonzero(tensor)
        total += tensor.size

        # Structured sparsity is only measured for linear (2-D) and conv2d (4-D) weights.
        if tensor.ndim not in (2, 4):
            continue
        # Absolute values are non-negative, so a summed slice is zero iff every
        # weight in it is zero.
        axis1_mass = tensor.sum(axis=tuple(a for a in range(tensor.ndim) if a != 1))
        axis0_mass = tensor.sum(axis=tuple(a for a in range(tensor.ndim) if a != 0))
        kept = np.count_nonzero(axis1_mass) * np.count_nonzero(axis0_mass)
        possible = axis1_mass.size * axis0_mass.size
        filter_count += kept
        filter_total += possible
        total_sparsity += 1 - kept / possible
        total_layer += 1

    elt_sparsity = (total - nonzero) / total if total else 0.0
    input_sparsity = (filter_total - filter_count) / filter_total if filter_total else 0.0
    output_sparsity = total_sparsity / total_layer if total_layer else 0.0
    return elt_sparsity, input_sparsity, output_sparsity
Пример #44
0
def _validate(data_group, model, criterion, device):
    """Evaluate ``model`` on ``dataloaders[data_group]`` and return its metrics.

    Runs the model in eval mode without gradient tracking. Loss is accumulated
    with a ``tnt.AverageValueMeter`` and top-1/top-5 accuracy with a
    ``tnt.ClassErrorMeter(accuracy=True)``. A coloured progress line is printed
    every 50 batches, and a summary line at the end.

    Args:
        data_group: key into the module-level ``dataloaders`` mapping.
        model: network to evaluate (left in eval mode).
        criterion: loss function called as ``criterion(output, labels)``.
        device: device the inputs and labels are moved to.

    Returns:
        ``(top1, top5, mean_loss)``.
    """
    loader = dataloaders[data_group]
    # TODO: the original note suggests dropping top-5 tracking.
    classerr = tnt.ClassErrorMeter(accuracy=True, topk=[1, 5])
    losses = {'objective_loss': tnt.AverageValueMeter()}
    # TODO: early-exit evaluation (per-exit error/loss meters) and a confusion
    # matrix display are not implemented in this version.
    batch_time = tnt.AverageValueMeter()
    # Number of batches the loader yields; unlike len(sampler) / batch_size, this
    # also counts a partial last batch.
    total_steps = len(loader)

    # Turn into evaluation mode.
    model.eval()
    end = time.time()
    with torch.no_grad():
        for validation_step, data in enumerate(loader):
            inputs = data[0].to(device)
            labels = data[1].to(device)
            output = model(inputs)

            loss = criterion(output, labels)
            losses['objective_loss'].add(loss.item())
            classerr.add(output.detach(), labels)

            batch_time.add(time.time() - end)
            end = time.time()
            steps_completed = validation_step + 1
            if steps_completed % 50 == 0:
                print('Test [{:5d}/{:5d}] \033[0;37;41mLoss {:.5f}\033[0m'
                      '\033[0;37;42m\tTop1 {:.5f}  Top5 {:.5f}\033[m'
                      '\tTime {:.5f}.'.format(steps_completed,
                                              int(total_steps),
                                              losses['objective_loss'].mean,
                                              classerr.value(1),
                                              classerr.value(5),
                                              batch_time.mean))

        print('==> \033[0;37;42mTop1 {:.5f}  Top5 {:.5f}\033[m'
              '\033[0;37;41m\tLoss: {:.5f}.\033[m'.format(
                  classerr.value(1), classerr.value(5),
                  losses['objective_loss'].mean))

    return classerr.value(1), classerr.value(5), losses['objective_loss'].mean
Пример #45
0
    def val(self):
        """Generate validation structures with the generator and write them as .gro files.

        The number of passes is read from ``cfg['validate']['n_gibbs']``. Each pass
        iterates a ``Mol_Rec_Generator`` over the validation data (``train=False``,
        ``res='inp'``, ``rand_rot=False``) and, for every item:

        * voxelizes the input positions under the ``val_bs`` rotation matrices;
        * generates atoms one at a time, merging each into the grid via
          ``torch.where`` with the ``repl`` mask;
        * converts the generated densities back to coordinates with ``avg_blob``
          and averages them over the rotations;
        * stores them as the molecule's atom positions.

        After each pass, every validation sample is written to
        ``output_dir/samples/<name>_<n>.gro``.

        The original atom positions of ``self.data.samples_val_inp`` are restored
        and generator/critic are put back into training mode. Both happen in
        ``finally``, so they also run when generation raises.
        """
        start = timer()

        resolution = self.cfg.getint('grid', 'resolution')
        grid_length = self.cfg.getfloat('grid', 'length')
        delta_s = grid_length / resolution
        sigma_inp = self.cfg.getfloat('grid', 'sigma_inp')
        sigma_out = self.cfg.getfloat('grid', 'sigma_out')
        grid = torch.from_numpy(make_grid_np(delta_s, resolution)).to(self.device)

        val_bs = self.cfg.getint('validate', 'batchsize')

        rot_mtxs = torch.from_numpy(rot_mtx_batch(val_bs)).to(self.device).float()
        rot_mtxs_transposed = torch.from_numpy(rot_mtx_batch(val_bs, transpose=True)).to(self.device).float()

        # Snapshot the current atom positions so they can be restored afterwards.
        samples_inp = self.data.samples_val_inp
        pos_dict = {a: a.pos for sample in samples_inp for a in sample.atoms}

        try:
            self.generator.eval()
            self.critic.eval()

            for n in range(0, self.cfg.getint('validate', 'n_gibbs')):
                g = iter(Mol_Rec_Generator(self.data, train=False, res='inp', rand_rot=False))
                for d in g:
                    with torch.no_grad():
                        inp_positions = np.array([d['positions']])
                        inp_positions = torch.matmul(torch.from_numpy(inp_positions).to(self.device).float(), rot_mtxs)
                        aa_grid = self.to_voxel(inp_positions, grid, sigma_inp)

                        mol = d['mol']

                        elems = (d['featvec'], d['repl'])
                        elems = self.transpose(self.insert_dim(self.to_tensor(elems)))

                        # NOTE(review): energy_ndx is computed but not used below.
                        energy_ndx = (d['bond_ndx'], d['angle_ndx'], d['dih_ndx'], d['lj_ndx'])
                        energy_ndx = self.repeat(self.to_tensor(energy_ndx), val_bs)

                        generated_atoms = []
                        for featvec, repl in zip(*elems):
                            features = torch.sum(aa_grid[:, :, None, :, :, :] * featvec[:, :, :, None, None, None], 1)

                            # generate fake atom
                            if self.z_dim != 0:
                                z = torch.empty(
                                    [features.shape[0], self.z_dim],
                                    dtype=torch.float32,
                                    device=self.device,
                                ).normal_()

                                fake_atom = self.generator(z, features)
                            else:
                                fake_atom = self.generator(features)
                            generated_atoms.append(fake_atom)

                            # update aa grids: keep aa_grid where repl is True, else take the generated atom
                            aa_grid = torch.where(repl[:, :, None, None, None], aa_grid, fake_atom)

                        generated_atoms = torch.cat(generated_atoms, dim=1)

                        coords = avg_blob(
                            generated_atoms,
                            res=resolution,
                            width=grid_length,
                            sigma=sigma_out,
                            device=self.device,
                        )

                        # Undo the input rotations and average over the val_bs copies.
                        coords = torch.matmul(coords, rot_mtxs_transposed)
                        coords = torch.sum(coords, 0) / val_bs

                        positions = coords.detach().cpu().numpy()
                        positions = np.dot(positions, mol.rot_mat.T)
                        for pos, atom in zip(positions, mol.atoms):
                            atom.pos = pos + mol.com

                samples_dir = self.out.output_dir / "samples"
                samples_dir.mkdir(exist_ok=True)

                for sample in self.data.samples_val_inp:
                    sample.write_aa_gro_file(samples_dir / (sample.name + "_" +str(n) + ".gro"))

        finally:
            # Restore the original positions even if generation failed part-way,
            # so the validation data is never left mutated.
            for sample in self.data.samples_val_inp:
                for a in sample.atoms:
                    a.pos = pos_dict[a]
            self.generator.train()
            self.critic.train()
            print("validation took ", timer()-start, "secs")
Пример #46
0
def main():
    """Train and validate a 3-D MONAI UNet on synthetic NIfTI data.

    Steps:

    * Generates 40 synthetic image/segmentation pairs in a temporary directory.
    * Trains on the first 20 for ``max_epochs`` epochs with Dice loss and Adam.
    * Every ``val_interval`` epochs, validates on the last 20 using sliding-window
      inference and the mean Dice metric.
    * Saves the best model's ``state_dict`` to ``best_metric_model.pth``.
    * Logs losses, metrics and example volumes to TensorBoard.

    Uses ``cuda:0`` when CUDA is available and the CPU otherwise. The temporary
    directory and the summary writer are released even if training raises.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory (removed automatically) and 40 random image, mask pairs
    with tempfile.TemporaryDirectory() as tempdir:
        print(f"generating synthetic data to {tempdir} (this may take a while)")
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

        images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
        segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

        # define transforms for image and segmentation
        train_imtrans = Compose([
            ScaleIntensity(),
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 2)),
            ToTensor(),
        ])
        train_segtrans = Compose([
            AddChannel(),
            RandSpatialCrop((96, 96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 2)),
            ToTensor(),
        ])
        val_imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
        val_segtrans = Compose([AddChannel(), ToTensor()])

        # define nifti dataset, data loader
        check_ds = NiftiDataset(images,
                                segs,
                                transform=train_imtrans,
                                seg_transform=train_segtrans)
        check_loader = DataLoader(check_ds,
                                  batch_size=10,
                                  num_workers=2,
                                  pin_memory=torch.cuda.is_available())
        im, seg = monai.utils.misc.first(check_loader)
        print(im.shape, seg.shape)

        # create a training data loader
        train_ds = NiftiDataset(images[:20],
                                segs[:20],
                                transform=train_imtrans,
                                seg_transform=train_segtrans)
        train_loader = DataLoader(train_ds,
                                  batch_size=4,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=torch.cuda.is_available())
        # create a validation data loader
        val_ds = NiftiDataset(images[-20:],
                              segs[-20:],
                              transform=val_imtrans,
                              seg_transform=val_segtrans)
        val_loader = DataLoader(val_ds,
                                batch_size=1,
                                num_workers=4,
                                pin_memory=torch.cuda.is_available())
        dice_metric = DiceMetric(include_background=True,
                                 to_onehot_y=False,
                                 sigmoid=True,
                                 reduction="mean")

        # create UNet, DiceLoss and Adam optimizer
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = monai.networks.nets.UNet(
            dimensions=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ).to(device)
        loss_function = monai.losses.DiceLoss(sigmoid=True)
        optimizer = torch.optim.Adam(model.parameters(), 1e-3)

        # start a typical PyTorch training
        max_epochs = 5
        val_interval = 2
        best_metric = -1
        best_metric_epoch = -1
        epoch_loss_values = list()
        metric_values = list()
        # Number of training batches per epoch (used for progress and global step).
        epoch_len = len(train_loader)
        writer = SummaryWriter()
        try:
            for epoch in range(max_epochs):
                print("-" * 10)
                print(f"epoch {epoch + 1}/{max_epochs}")
                model.train()
                epoch_loss = 0
                step = 0
                for batch_data in train_loader:
                    step += 1
                    inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                    print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                    writer.add_scalar("train_loss", loss.item(),
                                      epoch_len * epoch + step)
                epoch_loss /= step
                epoch_loss_values.append(epoch_loss)
                print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

                if (epoch + 1) % val_interval == 0:
                    model.eval()
                    with torch.no_grad():
                        metric_sum = 0.0
                        metric_count = 0
                        val_images = None
                        val_labels = None
                        val_outputs = None
                        for val_data in val_loader:
                            val_images, val_labels = val_data[0].to(
                                device), val_data[1].to(device)
                            roi_size = (96, 96, 96)
                            sw_batch_size = 4
                            val_outputs = sliding_window_inference(
                                val_images, roi_size, sw_batch_size, model)
                            value = dice_metric(y_pred=val_outputs, y=val_labels)
                            metric_count += len(value)
                            metric_sum += value.item() * len(value)
                        metric = metric_sum / metric_count
                        metric_values.append(metric)
                        if metric > best_metric:
                            best_metric = metric
                            best_metric_epoch = epoch + 1
                            torch.save(model.state_dict(), "best_metric_model.pth")
                            print("saved new best metric model")
                        print(
                            "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}"
                            .format(epoch + 1, metric, best_metric, best_metric_epoch))
                        writer.add_scalar("val_mean_dice", metric, epoch + 1)
                        # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                        plot_2d_or_3d_image(val_images,
                                            epoch + 1,
                                            writer,
                                            index=0,
                                            tag="image")
                        plot_2d_or_3d_image(val_labels,
                                            epoch + 1,
                                            writer,
                                            index=0,
                                            tag="label")
                        plot_2d_or_3d_image(val_outputs,
                                            epoch + 1,
                                            writer,
                                            index=0,
                                            tag="output")
        finally:
            writer.close()
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
Пример #47
0
def validate(val_loader, model, criterions, args, mode='valid'):
    """Evaluate ``model`` on ``val_loader`` and report top-1/top-5 precision.

    Parameters
    ----------
    val_loader : iterable
        Yields ``(input, target, _)`` triples; the third element is ignored.
    model : torch.nn.Module
        Network to evaluate. It is switched to eval mode and left there.
    criterions : sequence
        4-element sequence whose first element is the classification
        criterion; the other three are not used here.
    args : namespace
        Must provide ``print_freq`` (progress is printed every
        ``print_freq`` batches).
    mode : str
        ``'test'`` prefixes progress lines with ``Test:``; any other value
        uses ``Valid:``.

    Returns
    -------
    tuple(float, float)
        Average top-1 precision and average loss over the loader.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Only the classification criterion is needed; unpacking all four keeps
    # the original contract that ``criterions`` has exactly four entries.
    criterion, _, _, _ = criterions

    prefix = 'Test' if mode == 'test' else 'Valid'

    end = time.time()
    with torch.no_grad():
        for i, (inputs, target, _) in enumerate(val_loader):
            batch_size = inputs.shape[0]
            target = target.cuda()

            # compute output
            output = model(inputs)
            # The criterion value is divided by the batch size; presumably the
            # criterion uses sum reduction -- NOTE(review): confirm at its
            # construction site.
            loss = criterion(output, target) / float(batch_size)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('{0}: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          prefix,
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1,
                          top5=top5))

    print(
        ' ****** Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.3f} '
        .format(top1=top1, top5=top5, loss=losses))

    return top1.avg, losses.avg
Пример #48
0
def train_pi(label_loader,
             unlabel_loader,
             model,
             criterions,
             optimizer,
             epoch,
             args,
             weight_pi=20.0):
    """Train ``model`` for one epoch with the Pi-model consistency objective.

    Each step draws one labelled batch and one unlabelled batch. The
    labelled loader is restarted whenever it runs out, and the epoch length
    is the length of the unlabelled iterator. Every batch carries a second
    view of its inputs. That view goes through the same model under
    ``torch.no_grad()`` and serves as the consistency target.

    The optimised objective is::

        CE(labelled logits, targets) / n_labelled
        + args.weight_l1 * cal_reg_l1(model, l1_criterion)
        + weight_cl * weight_pi * MSE(logits, logits_alt) / (args.num_classes * n_total)

    ``weight_cl`` comes from ``cal_consistency_weight``. Judging by its
    ``end_ep`` argument, it ramps up over the first ``args.epochs // 2``
    epochs. NOTE(review): confirm this against that helper.

    Returns
    -------
    tuple
        ``(avg top-1, avg CE loss, avg Pi loss, avg consistency weight)``.
    """
    batch_timer = AverageMeter()
    load_timer = AverageMeter()
    ce_meter = AverageMeter()
    pi_meter = AverageMeter()
    prec1_meter = AverageMeter()
    prec5_meter = AverageMeter()
    weight_meter = AverageMeter()

    # switch to train mode
    model.train()

    ce_criterion, mse_criterion, _, l1_criterion = criterions

    tic = time.time()

    labelled_batches = iter(label_loader)
    unlabelled_batches = iter(unlabel_loader)
    steps = len(unlabelled_batches)
    ramp_end = (args.epochs // 2) * steps
    for step in range(steps):
        # consistency weight for this global step
        weight_cl = cal_consistency_weight(epoch * steps + step,
                                           end_ep=ramp_end,
                                           end_w=1.0)

        try:
            x_lab, y_lab, x1_lab = next(labelled_batches)
        except StopIteration:
            # labelled set is smaller than the unlabelled one: start over
            labelled_batches = iter(label_loader)
            x_lab, y_lab, x1_lab = next(labelled_batches)
        x_unl, _, x1_unl = next(unlabelled_batches)
        n_lab = x_lab.shape[0]
        n_total = n_lab + x_unl.shape[0]
        load_timer.update(time.time() - tic)

        y_lab = y_lab.cuda()
        x_lab, x1_lab, x_unl, x1_unl = (
            torch.autograd.Variable(t) for t in (x_lab, x1_lab, x_unl, x1_unl))
        stacked = torch.cat([x_lab, x_unl])
        stacked_alt = torch.cat([x1_lab, x1_unl])
        y_var = torch.autograd.Variable(y_lab)

        # student pass with gradients; second view is a fixed target
        logits = model(stacked)
        with torch.no_grad():
            logits_alt = model(stacked_alt)

        logits_lab = logits[:n_lab]
        ce_loss = ce_criterion(logits_lab, y_var) / float(n_lab)
        pi_loss = mse_criterion(logits, logits_alt) / float(
            args.num_classes * n_total)
        l1_penalty = cal_reg_l1(model, l1_criterion)
        objective = ce_loss + args.weight_l1 * l1_penalty + weight_cl * weight_pi * pi_loss

        # bookkeeping (all meters are weighted by the labelled batch size)
        prec1, prec5 = accuracy(logits_lab.data, y_lab, topk=(1, 5))
        ce_meter.update(ce_loss.item(), n_lab)
        pi_meter.update(pi_loss.item(), n_lab)
        prec1_meter.update(prec1.item(), n_lab)
        prec5_meter.update(prec5.item(), n_lab)
        weight_meter.update(weight_cl, n_lab)

        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        batch_timer.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'LossPi {loss_pi.val:.4f} ({loss_pi.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      step,
                      steps,
                      batch_time=batch_timer,
                      data_time=load_timer,
                      loss=ce_meter,
                      loss_pi=pi_meter,
                      top1=prec1_meter,
                      top5=prec5_meter))

    return prec1_meter.avg, ce_meter.avg, pi_meter.avg, weight_meter.avg
Пример #49
0
def _create_identity_grid(size, device="cuda"):
    """Return a ``(1, size, size, 2)`` identity sampling grid at pixel centres.

    ``F.affine_grid`` is called with ``align_corners=True`` so the corner
    samples lie exactly at -1 and 1. The grid is then scaled by
    ``(size - 1) / size``, which moves the samples to
    ``-1 + 1/size ... 1 - 1/size``.

    ``align_corners`` must be passed explicitly. Since PyTorch 1.3 it
    defaults to ``False``, which already returns that shrunken grid; the
    rescale below would then be applied twice.

    Parameters
    ----------
    size : int
        Height and width of the square grid.
    device : str or torch.device, optional
        Device of the returned tensor. Defaults to ``"cuda"``, matching the
        previous hard-coded ``torch.cuda.FloatTensor``.
    """
    with torch.no_grad():
        # identity affine transform
        id_theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]],
                                dtype=torch.float32, device=device)
        I = F.affine_grid(id_theta, torch.Size((1, 1, size, size)),
                          align_corners=True)
        I *= (size - 1) / size  # rescale the [-1, 1] identity onto pixel centres
        return I
Пример #50
0
def train_mt(label_loader,
             unlabel_loader,
             model,
             model_teacher,
             criterions,
             optimizer,
             epoch,
             args,
             ema_const=0.95,
             weight_mt=8.0):
    """Train ``model`` (student) for one epoch with the Mean-Teacher objective.

    Each step draws one labelled batch and one unlabelled batch. The
    labelled loader is restarted whenever it runs out, and the epoch length
    is ``len(iter(unlabel_loader))``. The student sees the first view of
    each batch. ``model_teacher`` sees the second view (``input1``) under
    ``torch.no_grad()``, and its logits are the consistency target. After
    every optimizer step, ``update_ema_variables(model, model_teacher,
    ema_const, global_step)`` is called. By its name it updates the teacher
    as an EMA of the student. NOTE(review): confirm against that helper.

    The optimised loss is::

        CE(student labelled logits, targets) / n_labelled
        + args.weight_l1 * cal_reg_l1(model, criterion_l1)
        + weight_cl * weight_mt * MSE(student, teacher) / (args.num_classes * n_total)

    Returns
    -------
    tuple
        ``(student top-1 avg, CE loss avg, consistency loss avg,
        teacher top-1 avg, consistency weight avg)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_cl = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_t = AverageMeter()
    top5_t = AverageMeter()
    weights_cl = AverageMeter()

    # switch to train mode (the teacher is also put in train mode here)
    model.train()
    model_teacher.train()

    criterion, criterion_mse, _, criterion_l1 = criterions

    end = time.time()

    label_iter = iter(label_loader)
    unlabel_iter = iter(unlabel_loader)
    len_iter = len(unlabel_iter)
    for i in range(len_iter):
        # set weights for the consistency loss; global_step is also passed to
        # the EMA update below
        global_step = epoch * len_iter + i
        weight_cl = cal_consistency_weight(global_step,
                                           end_ep=(args.epochs // 2) *
                                           len_iter,
                                           end_w=1.0)

        try:
            input, target, input1 = next(label_iter)
        except StopIteration:
            # labelled set exhausted before the unlabelled one: restart it
            label_iter = iter(label_loader)
            input, target, input1 = next(label_iter)
        input_ul, _, input1_ul = next(unlabel_iter)
        sl = input.shape
        su = input_ul.shape

        # total number of samples (labelled + unlabelled) in this step
        batch_size = sl[0] + su[0]
        # measure data loading time
        data_time.update(time.time() - end)
        target = target.cuda()
        input_var = torch.autograd.Variable(input)
        input1_var = torch.autograd.Variable(input1)
        input_ul_var = torch.autograd.Variable(input_ul)
        input1_ul_var = torch.autograd.Variable(input1_ul)
        # labelled rows come first, so output[:sl[0]] are the labelled logits
        input_concat_var = torch.cat([input_var, input_ul_var])
        input1_concat_var = torch.cat([input1_var, input1_ul_var])

        target_var = torch.autograd.Variable(target)

        # compute output: student with gradients, teacher without
        output = model(input_concat_var)
        with torch.no_grad():
            output1 = model_teacher(input1_concat_var)

        output_label = output[:sl[0]]
        output1_label = output1[:sl[0]]

        # consistency loss is computed on logits (softmax variant left disabled)
        #pred = F.softmax(output, 1)
        #pred1 = F.softmax(output1, 1)
        loss_ce = criterion(output_label, target_var) / float(sl[0])
        loss_cl = criterion_mse(output, output1) / float(
            args.num_classes * batch_size)

        reg_l1 = cal_reg_l1(model, criterion_l1)

        loss = loss_ce + args.weight_l1 * reg_l1 + weight_cl * weight_mt * loss_cl

        # measure accuracy (student and teacher) and record loss; meters are
        # weighted by the labelled batch size only
        prec1, prec5 = accuracy(output_label.data, target, topk=(1, 5))
        prec1_t, prec5_t = accuracy(output1_label.data, target, topk=(1, 5))
        losses.update(loss_ce.item(), input.size(0))
        losses_cl.update(loss_cl.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))
        top1_t.update(prec1_t.item(), input.size(0))
        top5_t.update(prec5_t.item(), input.size(0))
        weights_cl.update(weight_cl, input.size(0))

        # compute gradient and do SGD step, then update the teacher from the
        # freshly-updated student
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_ema_variables(model, model_teacher, ema_const, global_step)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'LossCL {loss_cl.val:.4f} ({loss_cl.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                  'PrecT@1 {top1_t.val:.3f} ({top1_t.avg:.3f})\t'
                  'PrecT@5 {top5_t.val:.3f} ({top5_t.avg:.3f})'.format(
                      epoch,
                      i,
                      len_iter,
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      loss_cl=losses_cl,
                      top1=top1,
                      top5=top5,
                      top1_t=top1_t,
                      top5_t=top5_t))

    return top1.avg, losses.avg, losses_cl.avg, top1_t.avg, weights_cl.avg
Пример #51
0
def _time_scaled_output(model, x, t_actual):
    """Return ``model(x)`` with each sample divided by ``sqrt(t_actual)``."""
    out = model(x)
    # Shape the per-sample scale as (batch, 1, ..., 1) so it broadcasts over
    # every non-batch dimension of ``out``. For 2-D outputs this matches the
    # previous ``.view(-1, 1).repeat(1, out.shape[1])``.
    scale = torch.sqrt(t_actual).view(-1, *([1] * (out.dim() - 1)))
    return torch.div(out, scale)


def _evaluate_time_scaled(model, data_loader, device):
    """Return (mean MSE, mean normalized L2 error) of ``model`` over ``data_loader``.

    The model is put in eval mode for the pass and back in train mode after.
    """
    model.eval()
    total_mse = 0.
    total_l2 = 0.
    with torch.no_grad():
        for x, y, t_actual in data_loader:
            x, y = x.to(device), y.to(device)
            t_actual = t_actual.to(device)
            out = _time_scaled_output(model, x, t_actual)
            total_mse += MSE(out, y).item()
            total_l2 += l2_normalized_error(out, y).item()
    model.train()
    n_batches = len(data_loader)
    return total_mse / n_batches, total_l2 / n_batches


def train_loop(model, optimizer, scheduler, start_epoch, end_epoch, l1_weight,
               device, train_data_loader, train_df, do_testing, test_every_n,
               test_data_loader, test_df, model_path, results_dd):
    """This is the main training loop.

    Each batch is ``(x, y, t_actual)``. The model output is divided
    per-sample by ``sqrt(t_actual)`` before the MSE against ``y`` is taken.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train.
    optimizer : torch.optim.Optimizer
        Optimization algorithm.
    scheduler : torch.optim.lr_scheduler
        Learning rate scheduler, stepped once per epoch.
    start_epoch, end_epoch : int
        Epochs ``start_epoch .. end_epoch - 1`` are run.
    l1_weight : float
        Currently unused (L1 regularization is disabled).
    device : torch.device
        Determines whether a GPU is used.
    train_data_loader : torch.utils.data.DataLoader
        Iterates over the train dataset.
    train_df : str
        Filepath to save intermediate training results.
    do_testing : bool
        Whether to test the model throughout training.
    test_every_n : int
        Test and checkpoint on epochs where ``ep % test_every_n == 0``.
    test_data_loader : torch.utils.data.DataLoader
        Iterates over the test dataset.
    test_df : str
        Filepath to save intermediate test results.
    model_path : str
        Filepath (formattable with the epoch number) to save the model.
    results_dd : dict
        Receives ``'train_mse'`` and ``'test_mse'`` of the last epoch, if
        any epoch ran.

    Returns
    -------
    model
        Trained model.
    """
    train_dd = {}
    test_dd = {}
    logging.info("Beginning training for {} epochs".format(end_epoch - start_epoch))

    # Reported when no epoch reaches the testing branch (previously NameError).
    test_mse = 0.

    model.train()
    for ep in range(start_epoch, end_epoch):
        t1 = default_timer()
        train_mse = 0
        for x, y, t_actual in train_data_loader:
            optimizer.zero_grad()

            x, y = x.to(device), y.to(device)
            t_actual = t_actual.to(device)

            out = _time_scaled_output(model, x, t_actual)
            mse = MSE(out, y)
            mse.backward()
            optimizer.step()

            train_mse += mse.item()

        scheduler.step()

        train_mse /= len(train_data_loader)

        t2 = default_timer()
        logging.info("Epoch: {}, time: {:.2f}, train_mse: {:.4f}".format(ep, t2-t1, train_mse))
        train_dd['epoch'] = ep
        train_dd['MSE'] = train_mse
        train_dd['time'] = t2-t1
        write_result_to_file(train_df, **train_dd)

        ########################################################
        # Intermediate testing and saving
        ########################################################
        if ep % test_every_n == 0:
            test_mse = 0.
            if do_testing:
                test_mse, test_l2_norm_error = _evaluate_time_scaled(
                    model, test_data_loader, device)

                test_dd['test_mse'] = test_mse
                test_dd['test_l2_normalized_error'] = test_l2_norm_error
                test_dd['epoch'] = ep

                write_result_to_file(test_df, **test_dd)
                logging.info("Test: Epoch: {}, test_mse: {:.4f}".format(ep, test_mse))
            torch.save(model, model_path.format(ep))

    torch.save(model, model_path.format(end_epoch))
    if end_epoch - start_epoch > 0:
        results_dd['train_mse'] = train_mse
        results_dd['test_mse'] = test_mse
    return model
Пример #52
0
def export_to_pb(model, inputs, *args, **kwargs):
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, inputs, f, *args, **kwargs)
    return f.getvalue()
Пример #53
0
def training_loop(hyperparameters):
    """Train a TD3 agent (player one) on the hockey environment.

    Parameters
    ----------
    hyperparameters : dict
        Configuration. Keys read here: ``save_path``, ``load_path``,
        ``max_iterations``, ``batch_size``, ``max_episodes``, ``train_mode``,
        ``closeness_factor``, ``closeness_decay``, ``optimizer``,
        ``policy_noise``, ``policy_noise_clip``, ``gamma``, ``delay``,
        ``tau``, ``lr``, ``weight_decay``, ``self_play``, ``weak_agent``,
        ``exploration_noise``, ``plot_performance`` and
        ``plot_performance_summary``.

    Side effects
    ------------
    - Creates ``save_path``, prompting on stdin if it already exists.
    - Writes ``parameters.json`` and ``performance.csv`` there, and
      optionally ``performance.png``.
    - Saves the agent every 500 episodes.
    """
    print(f"Starting training with hyperparameters: {hyperparameters}")
    save_path = hyperparameters["save_path"]
    load_path = hyperparameters["load_path"]

    # create the save path and save hyperparameter configuration
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    else:
        a = input("Warning, Directory already exists. Dou want to continue?")
        if a not in ["Y","y"]:
            raise Exception("Path already exists, please start with another path.")

    with open(save_path+ "/parameters.json", "w") as f:
        json.dump(hyperparameters, f)

    # general configurations
    state_dim = 18
    action_dim = 4
    max_action = 1
    iterations = hyperparameters["max_iterations"]
    batch_size = hyperparameters["batch_size"]
    max_episodes = hyperparameters["max_episodes"]
    train_mode = hyperparameters["train_mode"]
    closeness_factor = hyperparameters["closeness_factor"]
    c = closeness_factor

    # init the agent
    agent1 = TD3Agent([state_dim + action_dim, 256, 256, 1],
                        [state_dim, 256, 256, action_dim],
                        optimizer=hyperparameters["optimizer"],
                        policy_noise=hyperparameters["policy_noise"],
                        policy_noise_clip=hyperparameters["policy_noise_clip"],
                        gamma=hyperparameters["gamma"],
                        delay=hyperparameters["delay"],
                        tau=hyperparameters["tau"],
                        lr=hyperparameters["lr"],
                        max_action=max_action,
                        weight_decay=hyperparameters["weight_decay"])

    # load the agent if given
    loaded_state = False
    if load_path:
        agent1.load(load_path)
        loaded_state = True

    # define opponent
    if hyperparameters["self_play"]:
        agent2 = agent1
    else:
        agent2 = h_env.BasicOpponent(weak=hyperparameters["weak_agent"])

    # load environment and replay buffer
    replay_buffer = ReplayBuffer(state_dim, action_dim)

    if train_mode == "defense":
        env = h_env.HockeyEnv(mode=h_env.HockeyEnv.TRAIN_DEFENSE)
    elif train_mode == "shooting":
        env = h_env.HockeyEnv(mode=h_env.HockeyEnv.TRAIN_SHOOTING)
    else:
        env = h_env.HockeyEnv()

    # add figure to plot later
    if hyperparameters["plot_performance"]:
        fig, (ax_loss, ax_reward) = plt.subplots(2)
        ax_loss.set_xlim(0, max_episodes)
        ax_loss.set_ylim(0, 20)
        ax_reward.set_xlim(0, max_episodes)
        ax_reward.set_ylim(-30, 20)

    # Player one (agent1) acts on ``obs_last``: that is the observation
    # returned by env.reset()/env.step() and the one stored in the replay
    # buffer with a1. ``env.obs_agent_two()`` is the opponent's view and is
    # only fed to agent2.
    with HiddenPrints():
        # first sample enough data to start:
        obs_last = env.reset()
        for _ in range(batch_size*100):
            a1 = env.action_space.sample()[:4] if not loaded_state else agent1.act(obs_last)
            a2 = agent2.act(env.obs_agent_two())
            obs, r, d, info = env.step(np.hstack([a1,a2]))
            done = 1 if d else 0
            replay_buffer.add(obs_last, a1, obs, r, done)
            obs_last = obs
            if d:
                obs_last = env.reset()

    print("Finished collection of data prior to training")

    # tracking of performance (only episodes that terminate are recorded)
    episode_critic_loss = []
    episode_rewards = []
    win_count = []
    n_logged = 0  # recorded episodes already appended to performance.csv
    csv_path = save_path + "/performance.csv"
    if not os.path.isfile(csv_path):
        pd.DataFrame(data={"Episode_rewards":[], "Episode_critic_loss":[], "Win/Loss":[]}).to_csv(csv_path, sep=",", index=False)

    # Then start training
    for episode_count in range(max_episodes+1):
        obs_last = env.reset()
        total_reward = 0
        critic_loss = []

        # linearly decay the closeness-to-puck shaping weight; guard against
        # max_episodes == 0, which would divide by zero
        if hyperparameters["closeness_decay"] and max_episodes > 0:
            c = closeness_factor * (1 - episode_count/max_episodes)

        for i in range(iterations):
            # run the environment
            with HiddenPrints():
                with torch.no_grad():
                    a1 = agent1.act(obs_last) + np.random.normal(loc=0, scale=hyperparameters["exploration_noise"], size=action_dim)
                a2 = agent2.act(env.obs_agent_two())
                obs, r, d, info = env.step(np.hstack([a1,a2]))
            total_reward += r
            done = 1 if d else 0

            # modify reward with closeness-to-puck reward
            newreward = r + c * info["reward_closeness_to_puck"]

            # add to replay buffer
            replay_buffer.add(obs_last, a1, obs, newreward, done)
            obs_last = obs

            # sample minibatch and train
            states, actions, next_states, reward, done_batch = replay_buffer.sample(batch_size)
            loss = agent1.train(states, actions, next_states, reward, done_batch)
            # .cpu() so this also works when the loss lives on a GPU
            critic_loss.append(loss.detach().cpu().numpy())

            # if done, finish episode
            if d:
                episode_rewards.append(total_reward)
                episode_critic_loss.append(np.mean(critic_loss))
                win_count.append(info["winner"])
                print(f"Episode {episode_count} finished after {i} steps with a total reward of {total_reward}")

                # Online plotting. x-values come from the smoothed series
                # itself: episodes that hit the iteration cap are not
                # recorded, so episode_count can exceed the number of values.
                if hyperparameters["plot_performance"] and episode_count > 40 and len(episode_rewards) >= 30:
                    loss_ma = moving_average(episode_critic_loss, 30)
                    reward_ma = moving_average(episode_rewards, 30)
                    ax_loss.plot(list(range(len(loss_ma))), loss_ma, 'r-')
                    ax_reward.plot(list(range(len(reward_ma))), reward_ma, "r-")
                    plt.draw()
                    plt.pause(1e-17)

                break

        # Intermediate evaluation of win/loss and saving of model
        if episode_count % 500 == 0 and episode_count != 0:
            recent = win_count[-500:]
            n_recent = max(len(recent), 1)  # fewer than 500 may have been recorded
            print(f"The agents win ratio in the last 500 episodes was {recent.count(1)/n_recent}")
            print(f"The agents loose ratio in the last 500 episodes was {recent.count(-1)/n_recent}")
            try:
                agent1.save(save_path)
                print("saved model")
            except Exception:
                print("Saving Failed model failed")
            # append only episodes recorded since the previous write, so rows
            # are neither duplicated nor skipped
            pd.DataFrame(data={"Episode_rewards": episode_rewards[n_logged:], "Episode_critic_loss": episode_critic_loss[n_logged:], "Win/Loss": win_count[n_logged:]}).to_csv(csv_path, sep=",", index=False, mode="a", header=False)
            n_logged = len(episode_rewards)

    print(f"Finished training with a final mean reward of {np.mean(episode_rewards[-500:])}")

    # plot the performance summary: raw series plus a linear trend line
    if hyperparameters["plot_performance_summary"]:
        try:
            fig, (ax1, ax2) = plt.subplots(2)
            for ax, series in ((ax1, episode_critic_loss), (ax2, episode_rewards)):
                x = list(range(len(series)))
                poly1d_fn = np.poly1d(np.polyfit(x, series, 1))
                ax.plot(series)
                ax.plot(poly1d_fn(x))
            fig.show()
            fig.savefig(save_path + "/performance.png", bbox_inches="tight")
        except Exception:
            print("Failed saving figure")
Пример #54
0
def main(args):
    """Train an ``FCNetwork`` policy on kitchen data by regressing actions with MSE.

    Per epoch, the policy is trained on ``TrainKitchenDataset`` and then
    evaluated on ``TestKitchenDataset``. Train loss (per iteration) and mean
    test loss (per epoch) are logged to TensorBoard under
    ``args.log_dir + '{:04d}'.format(args.exp_id)``.

    Checkpoints:
    - ``model_<iter>.pth`` every ``args.save_iter`` iterations.
    - ``best_loss.pth`` whenever the mean test loss improves.

    Parameters
    ----------
    args : namespace
        Parsed command-line options; ``vars(args)`` is dumped to
        ``args.txt`` inside the log directory.
    """
    # log files
    log_dir = '{}{:04d}'.format(args.log_dir, args.exp_id)
    writer = SummaryWriter(log_dir=log_dir)
    try:
        # write hyper params to file (``with`` closes it even on error)
        with open(log_dir + '/args.txt', 'w') as arg_file:
            for arg_key, arg_value in vars(args).items():
                arg_file.write(arg_key + " = {}\n".format(arg_value))

        # dataset
        train_dataset = TrainKitchenDataset(path=args.data_path,
                                            delta=args.use_delta,
                                            normalize=args.normalize,
                                            traj_len=args.traj_len)

        train_dataset_loader = DataLoader(train_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=4)
        test_dataset = TestKitchenDataset(path=args.data_path,
                                          delta=args.use_delta,
                                          normalize=args.normalize,
                                          traj_len=args.traj_len)
        test_dataset_loader = DataLoader(test_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4)

        # define policy (input is 2 * observation_dim wide)
        policy = FCNetwork(obs_dim=2 * args.observation_dim,
                           act_dim=args.action_dim,
                           hidden_sizes=(32, 32))
        loss_criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(list(policy.parameters()), lr=args.lr)

        iter_num = 0
        best_test_loss = 1e10
        for e in range(args.num_epoch):
            policy.train()
            # training loop
            for data in train_dataset_loader:
                state, gt_act = data['state'], data['act']
                optimizer.zero_grad()
                pred_act = policy(state)
                loss = loss_criterion(pred_act, gt_act)
                loss.backward()
                optimizer.step()

                # log the values, basically loss
                iter_num += 1
                writer.add_scalar('train/loss', loss.item(), iter_num)
                print("Train_loss = {:.3f} iter = {:06d}".format(
                    loss.item(), iter_num))
                # periodic checkpoint
                if iter_num % args.save_iter == 0:
                    torch.save(policy.state_dict(),
                               log_dir + '/model_{}.pth'.format(iter_num))

            # TODO: Need to figure this out. Previously sketched as:
            #   params_after_opt = bc_agent.policy.get_param_values()
            #   bc_agent.policy.set_param_values(params_after_opt, set_new=True, set_old=True)

            # testing loop
            policy.eval()
            test_loss_list = []
            with torch.no_grad():
                for data in test_dataset_loader:
                    state, gt_act = data['state'], data['act']
                    pred_act = policy(state)
                    loss = loss_criterion(pred_act, gt_act)
                    test_loss_list.append(loss.item())
            test_loss = np.array(test_loss_list).mean()
            writer.add_scalar('test/loss', test_loss, iter_num)
            print(
                colored("Test_loss = {:.3f} epoch = {:06d}".format(test_loss, e),
                        'red'))
            # save model
            if test_loss < best_test_loss:
                torch.save(policy.state_dict(), log_dir + '/best_loss.pth')
                best_test_loss = test_loss
    finally:
        # flush pending TensorBoard events
        writer.close()
Пример #55
0
def _denoise_in_eval_mode(noisy_torch):
    """Return ``noisy_torch - net(noisy_torch)`` as a numpy array.

    The forward pass runs in eval mode without gradients. The network's
    previous train/eval mode is restored afterwards, so later training
    forward passes are not silently run in eval mode.
    """
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            return torch_to_np(noisy_torch - net(noisy_torch))
    finally:
        net.train(was_training)


def _undo_augmentations(outputs_chw):
    """Map each augmented CHW output back to the reference orientation, as HWC.

    Index 0 is only transposed. Indices 1-3 are rotated back with
    ``rot90(., 4 - aug)``. Later indices are rotated with ``rot90(., 8 - aug)``
    and then flipped vertically. This mirrors the inverse transforms used
    by the training script.
    """
    restored = [outputs_chw[0].transpose(1, 2, 0)]
    for aug in range(1, len(outputs_chw)):
        hwc = outputs_chw[aug].transpose(1, 2, 0)
        if aug < 4:
            restored.append(np.rot90(hwc, 4 - aug))
        else:
            restored.append(np.flipud(np.rot90(hwc, 8 - aug)))
    return restored


def _record_final_metrics(final_result, clean_hwc):
    """Score a denoised HWC estimate against the clean HWC image and save new bests.

    Updates the globals ``final_ssim``, ``psnr_2_max`` and ``final_ssim_max``.
    An image is written with ``save_image`` on the first evaluation or whenever
    PSNR improves. An image is also written whenever SSIM improves after the
    first evaluation.

    Returns:
        The PSNR of this evaluation.
    """
    global psnr_2_max, final_ssim, final_ssim_max

    clipped = np.clip(final_result, 0, 1)
    psnr_2 = compare_psnr(clean_hwc, clipped)
    final_ssim = compare_ssim(clean_hwc, clipped, data_range=1, multichannel=True)

    # 0 acts as the "nothing recorded yet" sentinel for the running maxima.
    if psnr_2_max == 0 or psnr_2_max < psnr_2:
        psnr_2_max = psnr_2
        tmp_name_p = f'{files_name[:-4]}_{sigma_now * 255:.2f}_{psnr_2:.2f}_final_{final_ssim:.4f}'
        save_image(tmp_name_p, clipped.transpose(2, 0, 1), output_path=output_dir)

    if final_ssim_max == 0:
        # The first SSIM value is only recorded, not saved.
        final_ssim_max = final_ssim
    elif final_ssim_max < final_ssim:
        final_ssim_max = final_ssim
        tmp_name = f'{files_name[:-4]}_{sigma_now * 255:.2f}_{final_ssim:.4f}_final_{psnr_2:.2f}'
        save_image(tmp_name, clipped.transpose(2, 0, 1), output_path=output_dir)

    return psnr_2


def closure():
    """Run one optimisation step for the residual-learning denoiser.

    Steps:
    - Compute the MSE between ``net(net_input)`` and the residual
      ``img_noisy_noisy_np - img_noisy_np``, then back-propagate it.
    - Track the running maxima ``psnr_noisy_max`` and ``psnr_max``.
    - Every 10 iterations, score the denoised estimate (PSNR/SSIM) and save
      improved results.

    When ``DATA_AUG`` is set, the steps above run for every augmented copy.
    The evaluated outputs are then averaged after undoing the augmentations.

    Returns:
        The loss tensor of the last forward pass, as ``optimizer.step`` expects.
    """
    global i, net_input, psnr_max, psnr_noisy_max, files_name, psnr_2_max, noisy_np
    global TRAIN_PLAN, noisy_np_norm, sigma_now, final_ssim, final_ssim_max
    global psnr_curve_max_record, ssim_curve_max_record, training_loss_record

    if DATA_AUG:
        n_aug = len(img_noisy_torch)
        out_effect_np = []
        # Sum of every augmentation's loss; the mean is recorded below.
        loss_add = 0
        for aug in range(n_aug):
            noisy_torch = np_to_torch(img_noisy_noisy_np[aug] - img_noisy_np[aug])
            out = net(net_input[aug])
            total_loss = mse(out, noisy_torch.type(dtype))
            total_loss.backward()
            loss_add += total_loss.item()

            psrn_noisy = compare_psnr(
                np.clip(img_noisy_np[aug], 0, 1),
                torch_to_np(net_input[aug]) - out.detach().cpu().numpy()[0],
            )
            if psnr_noisy_max == 0 or psnr_noisy_max < psrn_noisy:
                psnr_noisy_max = psrn_noisy

            if SAVE_DURING_TRAINING and i % save_every == 0:
                denoised = _denoise_in_eval_mode(img_noisy_torch[aug])
                out_effect_np.append(denoised)
                psnr_1 = compare_psnr(img_aug_np[aug], np.clip(denoised, 0, 1))
                if psnr_max == 0 or psnr_max < psnr_1:
                    psnr_max = psnr_1

        training_loss_record.append(loss_add / n_aug)
        # Only score the averaged estimate when every augmentation was evaluated.
        if i % 10 == 0 and len(out_effect_np) == n_aug:
            final_result = np.mean(_undo_augmentations(out_effect_np), 0)
            psnr_2 = _record_final_metrics(final_result, img_aug_np[0].transpose(1, 2, 0))

            print('%s Iteration %05d ,psnr 2: %f, psnr 2 max: %f, final ssim : %f, final ssim max: %f'
                  % (files_name, i, psnr_2, psnr_2_max, final_ssim, final_ssim_max))
            writer.add_scalar('final_test_psnr', psnr_2, i)
            writer.add_scalar('final_max_test_psnr', psnr_2_max, i)
            psnr_curve_max_record.append(psnr_2_max)
            ssim_curve_max_record.append(final_ssim_max)

    else:
        noisy_torch = np_to_torch(img_noisy_noisy_np - img_noisy_np)
        out = net(net_input)
        total_loss = mse(out, noisy_torch.type(dtype))
        total_loss.backward()

        psrn_noisy = compare_psnr(
            np.clip(img_noisy_np, 0, 1),
            torch_to_np(net_input) - out.detach().cpu().numpy()[0],
        )
        if psnr_noisy_max == 0 or psnr_noisy_max < psrn_noisy:
            psnr_noisy_max = psrn_noisy

        out_effect_np = _denoise_in_eval_mode(img_noisy_torch)
        psnr_1 = compare_psnr(img_np, np.clip(out_effect_np, 0, 1))
        if psnr_max == 0 or psnr_max < psnr_1:
            psnr_max = psnr_1

        training_loss_record.append(total_loss.item() / len(img_noisy_torch))
        if i % 10 == 0:
            final_result = out_effect_np.transpose(1, 2, 0)
            psnr_2 = _record_final_metrics(final_result, img_np.transpose(1, 2, 0))

            print('%s, sigma %f, Epoch %05d, psnr 2: %f, psnr 2 max: %f, final ssim : %f, final ssim max: %f'
                  % (files_name, sigma_now * 255, i, psnr_2, psnr_2_max, final_ssim, final_ssim_max))
            writer.add_scalar('final_test_psnr', psnr_2, i)
            writer.add_scalar('final_max_test_psnr', psnr_2_max, i)
            psnr_curve_max_record.append(psnr_2_max)
            ssim_curve_max_record.append(final_ssim_max)

    i += 1

    return total_loss
Пример #56
0
    def predict(self):
        """Run the model over ``self.data_loader`` and report accuracy metrics.

        Each batch is moved to ``self.device`` and scored by ``self.cnn.model``
        in eval mode without gradients. The per-sample result is appended to
        ``self.predictions[video]``. Per-class accuracy, average class accuracy
        and top-1/top-3 accuracy are printed as tables.

        NOTE(review): the loader is assumed to yield one sample per batch,
        because the code reads ``metadata[...][0]`` and calls ``.item()`` on
        the label. TODO confirm.

        Returns:
            dict: ``self.predictions``, mapping each video to a list of dicts
            with keys ``ape_id``, ``label``, ``prediction`` and ``start_frame``.
        """
        # Turn on evaluation mode for the network.
        self.cnn.model.eval()
        results = {"labels": [], "logits": [], "predictions": []}

        with torch.no_grad():
            for spatial_data, temporal_data, label, metadata in tqdm(
                self.data_loader, desc="Prediction", leave=False, unit="sample"
            ):
                spatial_data = spatial_data.to(self.device)
                temporal_data = temporal_data.to(self.device)
                label = label.to(self.device)

                ape_id = metadata["ape_id"][0]
                start_frame = metadata["start_frame"][0]
                video = metadata["video"][0]

                logits = self.cnn.model(spatial_data, temporal_data)

                prediction = logits.argmax().item()
                label_value = label.item()

                # Collect results for the metrics; raw logits are needed for top-k accuracy.
                results["labels"].append(label_value)
                results["logits"].append(logits.detach().cpu()[0].tolist())
                results["predictions"].append(prediction)

                self.predictions.setdefault(video, []).append(
                    {
                        "ape_id": ape_id.item(),
                        "label": label_value,
                        "prediction": prediction,
                        "start_frame": start_frame.item(),
                    }
                )

        # Logits must stay floating point: a LongTensor would truncate them to
        # integers and corrupt the top-k ranking.
        top1, top3 = metrics.compute_topk_accuracy(
            torch.tensor(results["logits"], dtype=torch.float),
            torch.LongTensor(results["labels"]),
            topk=(1, 3),
        )

        # Per-class accuracies, indexed by label value (0...8).
        class_accuracy = metrics.compute_class_accuracy(results["labels"], results["predictions"])
        class_accuracy_average = mean(class_accuracy.values())

        class_names = [
            "camera_interaction",
            "climbing_down",
            "climbing_up",
            "hanging",
            "running",
            "sitting",
            "sitting_on_back",
            "standing",
            "walking",
        ]

        print("==> Per Class Results")
        per_class_results = [
            ["Class", *class_names],
            ["Accuracy", *(f"{class_accuracy[idx]:2.2f}" for idx in range(len(class_names)))],
        ]
        print(tabulate(per_class_results, tablefmt="fancy_grid"))

        print("==> Overall Results")
        test_results = [
            ["Average Class Accuracy:", f"{class_accuracy_average:.2f}"],
            ["Top1 Accuracy:", f"{top1.item():.2f}"],
            ["Top3 Accuracy:", f"{top3.item():.2f}"],
        ]
        print(tabulate(test_results, tablefmt="fancy_grid"))

        return self.predictions
Пример #57
0
def evaluate_model(model, test_loader, device, criterion):
    """Evaluate a segmentation model and print loss, pixel accuracy, P/R/F1.

    Each batch is handled as follows:
    - The output is reduced to a per-pixel class index (argmax over dim 1).
    - Pixel accuracy is measured against ``target``.
    - Precision/recall use a 3x3-dilation tolerance. A predicted pixel counts
      as correct if it falls inside the dilated label. A label pixel counts as
      recalled if it falls inside the dilated prediction.
    - Batches with an empty prediction or an empty label are excluded from
      both precision and recall.

    F1 is computed from the averaged precision and recall.

    NOTE(review): assumes ``model`` returns ``(output, feature)`` and that
    labels/predictions are 0/1 masks, since they are scaled by 255 and cast to
    uint8 for cv2. Relies on the module-level ``config.label_height`` /
    ``config.label_width``.
    """
    model.eval()
    kernel = np.ones((3, 3), dtype=np.uint8)  # dilation kernel, identical for every batch
    n_batches = 0
    n_skipped = 0
    test_loss = 0.0
    correct = 0
    precision_sum = 0.0
    recall_sum = 0.0

    with torch.no_grad():
        for sample_batched in test_loader:
            n_batches += 1
            data = sample_batched['data'].to(device)
            target = sample_batched['label'].type(torch.LongTensor).to(device)
            output, _feature = model(data)
            # max() returns (values, indices); the indices are the predicted classes.
            pred = output.max(1, keepdim=True)[1]

            # Accuracy
            test_loss += criterion(output, target).item()  # sum up batch loss
            correct += pred.eq(target.view_as(pred)).sum().item()

            # Precision / recall with a 1-pixel dilation tolerance
            img = (torch.squeeze(pred).cpu().numpy() * 255).astype(np.uint8)
            lab = (torch.squeeze(target).cpu().numpy() * 255).astype(np.uint8)
            label_dilated = cv2.dilate(lab, kernel)
            pred_dilated = cv2.dilate(img, kernel)

            n_pred = np.count_nonzero(img)
            n_label = np.count_nonzero(lab)
            if n_pred == 0 or n_label == 0:
                # Precision or recall is undefined here; skip the batch for both
                # so the sums and the denominator stay consistent.
                n_skipped += 1
                continue
            precision_sum += np.count_nonzero((img > 0) & (label_dilated > 0)) / n_pred
            recall_sum += np.count_nonzero((pred_dilated > 0) & (lab > 0)) / n_label

    # Average over the batches actually seen (the last batch may be partial).
    if n_batches:
        test_loss /= n_batches
    n_pixels = len(test_loader.dataset) * config.label_height * config.label_width
    test_acc = 100. * int(correct) / n_pixels
    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.5f}%)'.format(
        test_loss, int(correct), n_pixels, test_acc))

    n_scored = n_batches - n_skipped
    if n_scored > 0:
        precision = precision_sum / n_scored
        recall = recall_sum / n_scored
    else:
        precision = recall = 0.0
    if precision + recall > 0:
        F1_measure = (2 * precision * recall) / (precision + recall)
    else:
        F1_measure = 0.0
    print('Precision: {:.5f}, Recall: {:.5f}, F1_measure: {:.5f}\n'.format(
        precision, recall, F1_measure))
Пример #58
0
# -----------------------------------
# Restore the newest checkpoint for this hyper-parameter combination and
# prepare the encoder/decoder inputs for a sample sentence.

n_mel_frames = 511  # number of mel frames prepared for decoding

logs_idx = f'emb_lr{emb_lr}-enc_lr{enc_lr}-dec_lr{dec_lr}-batch_size{batch_size}'
saves = glob.glob(f'logs/{logs_idx}/*.pt')
if not saves:
    raise FileNotFoundError(f'No checkpoint (*.pt) found in logs/{logs_idx}')

# Sorted by modification time, so the last entry is the newest checkpoint.
saves.sort(key=os.path.getmtime)
# map_location lets a checkpoint saved on a GPU be restored on the current device.
checkpoint = torch.load(saves[-1], map_location=device)
for module_key, module in (
    ('text_embedding', text_embedding),
    ('encoder', encoder),
    ('decoder', decoder),
):
    module.load_state_dict(checkpoint[module_key])
    module.eval()

with torch.no_grad():
    text_data = parse_text('hello, this is just a test').to(device)
    text_data = text_data.unsqueeze(0)  # add a batch dimension
    text_emb = text_embedding(text_data)

    # 1-based positions; text_mask marks positions equal to 0, so with no
    # padding here it is all False.
    text_pos = torch.arange(1, text_data.size(1) + 1, device=device).unsqueeze(0)
    text_pos_emb = pos_embedding_(text_pos)
    text_mask = (text_pos == 0).unsqueeze(1)
    enc_out, att_heads_enc = encoder(text_emb, text_mask, text_pos_emb)

    mel_pos = torch.arange(1, n_mel_frames + 1, device=device).view(1, n_mel_frames)
    mel_pos_emb_ = pos_embedding(mel_pos)
    # True strictly above the diagonal; presumably True = masked (future frames).
    mel_mask_ = torch.triu(torch.ones(n_mel_frames, n_mel_frames, dtype=torch.bool), 1).unsqueeze(0).to(device)
    # [B, T, C], [B, T, C], [B, T, 1], [B, T, T_text]
    mel = torch.zeros(1, n_mel_frames, 80, device=device)
Пример #59
0
def infidelity(
    forward_func: Callable,
    perturb_func: Callable,
    inputs: TensorOrTupleOfTensorsGeneric,
    attributions: TensorOrTupleOfTensorsGeneric,
    baselines: BaselineType = None,
    additional_forward_args: Any = None,
    target: TargetType = None,
    n_perturb_samples: int = 10,
    max_examples_per_batch: int = None,
) -> Tensor:
    r"""
    Explanation infidelity represents the expected mean-squared error
    between the explanation multiplied by a meaningful input perturbation
    and the differences between the predictor function at its input
    and perturbed input.
    More details about the measure can be found in the following paper:
    https://arxiv.org/pdf/1901.09392.pdf

    It is derived from the completeness property of well-known attribution
    algorithms and is a computationally more efficient and generalized
    notion of Sensitivy-n. The latter measures correlations between the sum
    of the attributions and the differences of the predictor function at
    its input and fixed baseline. More details about the Sensitivity-n can
    be found here:
    https://arxiv.org/pdf/1711.06104.pdfs

    The users can perturb the inputs any desired way by providing any
    perturbation function that takes the inputs (and optionally baselines)
    and returns perturbed inputs or perturbed inputs and corresponding
    perturbations.

    This specific implementation is primarily tested for attribution-based
    explanation methods but the idea can be expanded to use for non
    attribution-based interpretability methods as well.

    Args:

        forward_func (callable):
                The forward function of the model or any modification of it.

        perturb_func (callable):
                The perturbation function of model inputs. This function takes
                model inputs and optionally baselines as input arguments and returns
                either a tuple of perturbations and perturbed inputs or just
                perturbed inputs. For example:

                >>> def my_perturb_func(inputs):
                >>>   <MY-LOGIC-HERE>
                >>>   return perturbations, perturbed_inputs

                If we want to only return perturbed inputs and compute
                perturbations internally then we can wrap perturb_func with
                `infidelity_perturb_func_decorator` decorator such as:

                >>> from captum.metrics import infidelity_perturb_func_decorator

                >>> @infidelity_perturb_func_decorator(<multipy_by_inputs flag>)
                >>> def my_perturb_func(inputs):
                >>>   <MY-LOGIC-HERE>
                >>>   return perturbed_inputs

                In case `multipy_by_inputs` is False we compute perturbations by
                `input - perturbed_input` difference and in case `multipy_by_inputs`
                flag is True we compute it by dividing
                (input - perturbed_input) by (input - baselines).
                The user needs to only return perturbed inputs in `perturb_func`
                as described above.

                `infidelity_perturb_func_decorator` needs to be used with
                `multipy_by_inputs` flag set to False in case infidelity
                score is being computed for attribution maps that are local aka
                that do not factor in inputs in the final attribution score.
                Such attribution algorithms include Saliency, GradCam, Guided Backprop,
                or Integrated Gradients and DeepLift attribution scores that are already
                computed with `multipy_by_inputs=False` flag.

                If there are more than one inputs passed to infidelity function those
                will be passed to `perturb_func` as tuples in the same order as they
                are passed to infidelity function.

                If inputs
                 - is a single tensor, the function needs to return a tuple
                   of perturbations and perturbed input such as:
                   perturb, perturbed_input and only perturbed_input in case
                   `infidelity_perturb_func_decorator` is used.
                 - is a tuple of tensors, corresponding perturbations and perturbed
                   inputs must be computed and returned as tuples in the
                   following format:

                   (perturb1, perturb2, ... perturbN), (perturbed_input1,
                   perturbed_input2, ... perturbed_inputN)

                   Similar to previous case here as well we need to return only
                   perturbed inputs in case `infidelity_perturb_func_decorator`
                   decorates out `perturb_func`.
                It is important to note that for performance reasons `perturb_func`
                isn't called for each example individually but on a batch of
                input examples that are repeated `max_examples_per_batch / batch_size`
                times within the batch.

        inputs (tensor or tuple of tensors):  Input for which
                attributions are computed. If forward_func takes a single
                tensor as input, a single input tensor should be provided.
                If forward_func takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples (aka batch size), and if
                multiple input tensors are provided, the examples must
                be aligned appropriately.

        baselines (scalar, tensor, tuple of scalars or tensors, optional):
                Baselines define reference values which sometimes represent ablated
                values and are used to compare with the actual inputs to compute
                importance scores in attribution algorithms. They can be represented
                as:

                - a single tensor, if inputs is a single tensor, with
                  exactly the same dimensions as inputs or the first
                  dimension is one and the remaining dimensions match
                  with inputs.

                - a single scalar, if inputs is a single tensor, which will
                  be broadcasted for each input value in input tensor.

                - a tuple of tensors or scalars, the baseline corresponding
                  to each tensor in the inputs' tuple can be:

                - either a tensor with matching dimensions to
                  corresponding tensor in the inputs' tuple
                  or the first dimension is one and the remaining
                  dimensions match with the corresponding
                  input tensor.

                - or a scalar, corresponding to a tensor in the
                  inputs' tuple. This scalar value is broadcasted
                  for corresponding input tensor.

                Default: None

        attributions (tensor or tuple of tensors):
                Attribution scores computed based on an attribution algorithm.
                This attribution scores can be computed using the implementations
                provided in the `captum.attr` package. Some of those attribution
                approaches are so called global methods, which means that
                they factor in model inputs' multiplier, as described in:
                https://arxiv.org/pdf/1711.06104.pdf
                Many global attribution algorithms can be used in local modes,
                meaning that the inputs multiplier isn't factored in the
                attribution scores.
                This can be done duing the definition of the attribution algorithm
                by passing `multipy_by_inputs=False` flag.
                For example in case of Integrated Gradients (IG) we can obtain
                local attribution scores if we define the constructor of IG as:
                ig = IntegratedGradients(multipy_by_inputs=False)

                Some attribution algorithms are inherently local.
                Examples of inherently local attribution methods include:
                Saliency, Guided GradCam, Guided Backprop and Deconvolution.

                For local attributions we can use real-valued perturbations
                whereas for global attributions that perturbation is binary.
                https://arxiv.org/pdf/1901.09392.pdf

                If we want to compute the infidelity of global attributions we
                can use a binary perturbation matrix that will allow us to select
                a subset of features from `inputs` or `inputs - baselines` space.
                This will allow us to approximate sensitivity-n for a global
                attribution algorithm.

                `infidelity_perturb_func_decorator` function decorator is a helper
                function that computes perturbations under the hood if perturbed
                inputs are provided.

                For more details about how to use `infidelity_perturb_func_decorator`,
                please, read the documentation about `perturb_func`

                Attributions have the same shape and dimensionality as the inputs.
                If inputs is a single tensor then the attributions is a single
                tensor as well. If inputs is provided as a tuple of tensors
                then attributions will be tuples of tensors as well.

        additional_forward_args (any, optional): If the forward function
                requires additional arguments other than the inputs for
                which attributions should not be computed, this argument
                can be provided. It must be either a single additional
                argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors
                or any arbitrary python types. These arguments are provided to
                forward_func in order, following the arguments in inputs.
                Note that the perturbations are not computed with respect
                to these arguments. This means that these arguments aren't
                being passed to `perturb_func` as an input argument.

                Default: None
        target (int, tuple, tensor or list, optional): Indices for selecting
                predictions from output(for classification cases,
                this is usually the target class).
                If the network returns a scalar value per example, no target
                index is necessary.
                For general 2D outputs, targets can be either:

                - A single integer or a tensor containing a single
                  integer, which is applied to all input examples

                - A list of integers or a 1D tensor, with length matching
                  the number of examples in inputs (dim 0). Each integer
                  is applied as the target for the corresponding example.

                  For outputs with > 2 dimensions, targets can be either:

                - A single tuple, which contains #output_dims - 1
                  elements. This target index is applied to all examples.

                - A list of tuples with length equal to the number of
                  examples in inputs (dim 0), and each tuple containing
                  #output_dims - 1 elements. Each tuple is applied as the
                  target for the corresponding example.

                Default: None
        n_perturb_samples (int, optional): The number of times input tensors
                are perturbed. Each input example in the inputs tensor is expanded
                `n_perturb_samples`
                times before calling `perturb_func` function.

                Default: 10
        max_examples_per_batch (int, optional): The number of maximum input
                examples that are processed together. In case the number of
                examples (`input batch size * n_perturb_samples`) exceeds
                `max_examples_per_batch`, they will be sliced
                into batches of `max_examples_per_batch` examples and processed
                in a sequential order. If `max_examples_per_batch` is None, all
                examples are processed together. `max_examples_per_batch` should
                at least be equal `input batch size` and at most
                `input batch size * n_perturb_samples`.

                Default: None
    Returns:

        infidelities (tensor): A tensor of scalar infidelity scores per
                input example. The first dimension is equal to the
                number of examples in the input batch and the second
                dimension is one.

    Examples::
        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> saliency = Saliency(net)
        >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
        >>> # Computes saliency maps for class 3.
        >>> attribution = saliency.attribute(input, target=3)
        >>> # define a perturbation function for the input
        >>> def perturb_fn(inputs):
        >>>    noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
        >>>    return noise, inputs - noise
        >>> # Computes infidelity score for saliency maps
        >>> infid = infidelity(net, perturb_fn, input, attribution)
    """

    def _generate_perturbations(
        current_n_perturb_samples: int,
    ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
        r"""
        The perturbations are generated for each example
        `current_n_perturb_samples` times.

        For performance reasons we are not calling `perturb_func` on each example but
        on a batch that contains `current_n_perturb_samples`
        repeated instances per example.
        """

        def call_perturb_func():
            r""""""
            baselines_pert = None
            inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
            if len(inputs_expanded) == 1:
                inputs_pert = inputs_expanded[0]
                if baselines_expanded is not None:
                    baselines_pert = cast(Tuple, baselines_expanded)[0]
            else:
                inputs_pert = inputs_expanded
                baselines_pert = baselines_expanded
            return (
                perturb_func(inputs_pert, baselines_pert)
                if baselines_pert is not None
                else perturb_func(inputs_pert)
            )

        inputs_expanded = tuple(
            torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
            for input in inputs
        )

        baselines_expanded = baselines
        if baselines is not None:
            baselines_expanded = tuple(
                baseline.repeat_interleave(current_n_perturb_samples, dim=0)
                if isinstance(baseline, torch.Tensor)
                and baseline.shape[0] == input.shape[0]
                and baseline.shape[0] > 1
                else baseline
                for input, baseline in zip(inputs, cast(Tuple, baselines))
            )

        return call_perturb_func()

    def _validate_inputs_and_perturbations(
        inputs: Tuple[Tensor, ...],
        inputs_perturbed: Tuple[Tensor, ...],
        perturbations: Tuple[Tensor, ...],
    ) -> None:
        # asserts the sizes of the perturbations and inputs
        assert len(perturbations) == len(inputs), (
            """The number of perturbed
            inputs and corresponding perturbations must have the same number of
            elements. Found number of inputs is: {} and perturbations:
            {}"""
        ).format(len(perturbations), len(inputs))

        # asserts the shapes of the perturbations and perturbed inputs
        for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
            assert perturb[0].shape == input_perturbed[0].shape, (
                """Perturbed input
                and corresponding perturbation must have the same shape and
                dimensionality. Found perturbation shape is: {} and the input shape
                is: {}"""
            ).format(perturb[0].shape, input_perturbed[0].shape)

    def _next_infidelity(current_n_perturb_samples: int) -> Tensor:
        """Compute the per-example squared infidelity error for one perturbation batch.

        Draws ``current_n_perturb_samples`` perturbations per example via
        ``_generate_perturbations``, runs ``forward_func`` on both the original and
        the perturbed inputs, and compares the output difference against the
        attribution-weighted perturbation.

        Args:
            current_n_perturb_samples: number of perturbation samples drawn per
                example in this batch.

        Returns:
            Tensor of shape ``(bsz,)``. For every example it holds the sum, over its
            perturbation samples, of
            ``(sum(attribution * perturbation) - (f(inputs) - f(inputs_perturbed)))**2``.
        """
        raw_perturbations, raw_inputs_perturbed = _generate_perturbations(
            current_n_perturb_samples
        )
        perturbations = _format_tensor_into_tuples(raw_perturbations)
        inputs_perturbed = _format_tensor_into_tuples(raw_inputs_perturbed)

        _validate_inputs_and_perturbations(
            cast(Tuple[Tensor, ...], inputs),
            cast(Tuple[Tensor, ...], inputs_perturbed),
            cast(Tuple[Tensor, ...], perturbations),
        )

        n_samples = current_n_perturb_samples
        interleave = ExpansionTypes.repeat_interleave

        # Targets / extra args are expanded with repeat_interleave so they line up
        # with the perturbation samples of each example.
        perturbed_out = _run_forward(
            forward_func,
            inputs_perturbed,
            _expand_target(target, n_samples, expansion_type=interleave),
            _expand_additional_forward_args(
                additional_forward_args, n_samples, expansion_type=interleave
            ),
        )

        # Unperturbed outputs, repeated so each example's output sits next to its
        # perturbed counterparts.
        original_out = torch.repeat_interleave(
            _run_forward(forward_func, inputs, target, additional_forward_args),
            n_samples,
            dim=0,
        )
        output_drop = original_out - perturbed_out

        # Per-sample dot product of attribution and perturbation, accumulated over
        # all input tensors (starts from 0 like the builtin ``sum``).
        dot_products = 0
        for attr, pert in zip(attributions, perturbations):
            attr_rep = torch.repeat_interleave(attr, n_samples, dim=0)
            flat = (attr_rep * pert).view(attr_rep.size(0), -1)
            dot_products = dot_products + torch.sum(flat, dim=1)

        squared_error = torch.pow(dot_products - output_drop.view(-1), 2)
        return torch.sum(squared_error.view(bsz, -1), dim=1)

    # perform argument formattings
    inputs = _format_input(inputs)  # type: ignore
    if baselines is not None:
        baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
    additional_forward_args = _format_additional_forward_args(additional_forward_args)
    attributions = _format_tensor_into_tuples(attributions)  # type: ignore

    # Make sure that inputs and corresponding attributions have matching sizes.
    assert len(inputs) == len(attributions), (
        """The number of tensors in the inputs and
        attributions must match. Found number of tensors in the inputs is: {} and in the
        attributions: {}"""
    ).format(len(inputs), len(attributions))
    for inp, attr in zip(inputs, attributions):
        assert inp.shape == attr.shape, (
            """Inputs and attributions must have
        matching shapes. One of the input tensor's shape is {} and the
        attribution tensor's shape is: {}"""
        ).format(inp.shape, attr.shape)

    bsz = inputs[0].size(0)
    with torch.no_grad():
        metrics_sum = _divide_and_aggregate_metrics(
            cast(Tuple[Tensor, ...], inputs),
            n_perturb_samples,
            _next_infidelity,
            max_examples_per_batch=max_examples_per_batch,
        )
    return metrics_sum * 1 / n_perturb_samples
Пример #60
0
    def forward(self, y, x):
        """Evolve a two-level system step by step and return the batch mean of W.

        The drive side ``x`` is treated as a constant: everything derived from it
        is detached, so gradients only reach the target side ``y``.

        Channel usage, as read by the code below (NOTE(review): confirm against the
        data pipeline): ``x[..., 0]`` is used as sin(theta_D), ``x[..., 1]`` as
        sin(phi_D), ``x[..., 3]`` as cos(phi_D). ``theta_T = atan2(y[..., 0], y[..., 2])``
        and ``phi_T = atan2(y[..., 1], y[..., 3])``.

        Args:
            y: target-side tensor indexed ``y[batch, time, channel]``; its time axis
                must have length ``self.N``.
            x: drive-side tensor indexed ``x[batch, time, channel]``; the first
                ``self.N - 1`` time steps are used.

        Returns:
            Scalar tensor: the mean over the batch of
            ``W = sum_i Re tr(rho_{i+1} (H_T[i+1] - H_T[i]))``. This assumes
            ``real_matmul`` (defined elsewhere) multiplies complex matrices given as
            (real, imag) pairs and returns the product's (real, imag) parts.
        """
        n_steps = self.N - 1
        batch = x.shape[0]
        # Allocate every buffer on the inputs' device and at their common dtype;
        # default CPU float32 buffers break CUDA inputs and downcast float64 inputs.
        opts = {"dtype": torch.promote_types(x.dtype, y.dtype), "device": x.device}

        # No grad for drive side.
        drive = x.detach()
        sin_theta_D = drive[:, :n_steps, 0]
        sin_phi_D = drive[:, :n_steps, 1]
        cos_phi_D = drive[:, :n_steps, 3]

        # Initial density matrix rho_0 = 0.5 * [[1, e^{-i phi}], [e^{i phi}, 1]],
        # using the first drive step's phase.
        rho_real = torch.full((batch, 2, 2), 0.5, **opts)
        rho_imag = torch.zeros((batch, 2, 2), **opts)
        rho_real[:, 1, 0] *= cos_phi_D[:, 0]
        rho_real[:, 0, 1] *= cos_phi_D[:, 0]
        rho_imag[:, 1, 0] = 0.5 * sin_phi_D[:, 0]
        rho_imag[:, 0, 1] = -0.5 * sin_phi_D[:, 0]

        # Target Hamiltonian H_T[t] = 0.5 * sin(theta_T) * [[0, e^{-i phi_T}], [e^{i phi_T}, 0]],
        # shape [batch, N, 2, 2].
        theta_T = torch.atan2(y[:, :, 0], y[:, :, 2])
        phi_T = torch.atan2(y[:, :, 1], y[:, :, 3])
        sin_theta_T = torch.sin(theta_T)
        cos_phi_T = torch.cos(phi_T)
        sin_phi_T = torch.sin(phi_T)

        H_T_real = torch.zeros((y.shape[0], self.N, 2, 2), **opts)
        H_T_imag = torch.zeros((y.shape[0], self.N, 2, 2), **opts)
        H_T_real[:, :, 1, 0] = cos_phi_T * sin_theta_T / 2
        H_T_real[:, :, 0, 1] = cos_phi_T * sin_theta_T / 2
        H_T_imag[:, :, 1, 0] = sin_phi_T * sin_theta_T / 2
        H_T_imag[:, :, 0, 1] = -1 * sin_phi_T * sin_theta_T / 2

        # alpha: lower off-diagonal element of H_D + H_T for each evolution step,
        # shape [batch, N-1]; the total Hamiltonian is [[0, conj(alpha)], [alpha, 0]].
        sin_theta_T_step = sin_theta_T[:, :n_steps]
        alpha_real = 0.5 * (sin_theta_D * cos_phi_D + sin_theta_T_step * cos_phi_T[:, :n_steps])
        alpha_imag = 0.5 * (sin_theta_D * sin_phi_D + sin_theta_T_step * sin_phi_T[:, :n_steps])

        # |alpha| is taken from alpha itself so it matches the alpha used in U
        # (the expanded closed form only agrees when x's sin/cos channels are
        # normalized). The squared magnitude is floored at the smallest normal
        # float: otherwise a zero combined drive makes sin(|alpha| dt) / |alpha|
        # evaluate to 0/0 = NaN and the sqrt gradient at 0 is infinite. Values
        # above the floor are unaffected.
        alpha_sq = alpha_real ** 2 + alpha_imag ** 2
        abs_alpha = torch.sqrt(
            torch.clamp(alpha_sq, min=torch.finfo(alpha_sq.dtype).tiny)
        )

        # U = exp(-i H dt) = cos(|alpha| dt) I - i sin(|alpha| dt) / |alpha| * H,
        # because H^2 = |alpha|^2 I for this off-diagonal H. Shape [batch, N-1, 2, 2].
        c = torch.cos(abs_alpha * self.dt)
        s = torch.sin(abs_alpha * self.dt) / abs_alpha

        U_real = torch.zeros((batch, n_steps, 2, 2), **opts)
        U_imag = torch.zeros((batch, n_steps, 2, 2), **opts)
        U_real[:, :, 0, 0] = c
        U_real[:, :, 1, 1] = c
        U_real[:, :, 0, 1] = -alpha_imag * s
        U_real[:, :, 1, 0] = alpha_imag * s
        U_imag[:, :, 0, 1] = -alpha_real * s
        U_imag[:, :, 1, 0] = -alpha_real * s

        # Propagate rho_{i+1} = U_i rho_i U_i^dagger (U^dagger passed as transposed
        # real part and negated transposed imaginary part) and accumulate the real
        # trace of rho_{i+1} (H_T[i+1] - H_T[i]).
        W = torch.zeros(batch, **opts)
        for i in range(n_steps):
            U_i_real = U_real[:, i]
            U_i_imag = U_imag[:, i]
            helper_real, helper_imag = real_matmul(U_i_real, U_i_imag, rho_real, rho_imag)
            rho_real, rho_imag = real_matmul(
                helper_real,
                helper_imag,
                torch.transpose(U_i_real, 1, 2),
                -1 * torch.transpose(U_i_imag, 1, 2),
            )
            A_real, _ = real_matmul(
                rho_real,
                rho_imag,
                H_T_real[:, i + 1] - H_T_real[:, i],
                H_T_imag[:, i + 1] - H_T_imag[:, i],
            )
            W += A_real[:, 0, 0] + A_real[:, 1, 1]

        return torch.mean(W)