def val_epoch(self):
        """Run one validation pass over ``self.val_loader``.

        Accumulates average loss and accuracy over the whole validation
        set, renders a pkbar progress bar, logs both metrics to the
        TensorBoard writer, and advances ``self.val_epoch_step``.
        """
        self.classifier.eval()  # disable dropout / batch-norm updates
        val_per_epoch = len(self.val_loader)
        print('Val', end='\t')
        # NOTE(review): the bar is labeled with self.start_epoch, not the
        # epoch actually being validated — confirm this is intentional.
        val_pbar = pkbar.Kbar(target=val_per_epoch, epoch=self.start_epoch, num_epochs=self.epochs)
        total_loss = 0
        total_correct = 0  # accumulates 0-d tensors (``correct`` is a tensor)
        total = 0
        for i, data in enumerate(self.val_loader):
            with torch.no_grad():  # no gradients needed during validation
                x, label = data
                # In-place squeeze of dim 1 — assumes inputs are batched as
                # (batch, 1, seq_len), e.g. from a tokenizer — TODO confirm.
                ids, atten_mask = x[0].squeeze_(1).to('cuda'), x[1].squeeze_(1).to('cuda')
                label = label.to('cuda')
                pred = self.classifier(ids, atten_mask)
                pred_classes = torch.argmax(pred,dim=1)
                loss = self.criterion(pred, label)
                correct = (label == pred_classes).sum()
                total_loss += loss.item()
                total_correct += correct
                total += label.shape[0]
                val_pbar.update(i, values=[('step loss', loss), ('step accuray', correct/label.shape[0])])

        epoch_loss = total_loss / (len(self.val_loader))
        epoch_acc = (total_correct/total)*100
        val_pbar.add(1, values=[('epoch loss', epoch_loss), ('epoch accuray', epoch_acc)])
        self.writer.add_scalar('val/loss', epoch_loss, global_step=self.val_epoch_step)
        self.writer.add_scalar('val/accuracy', epoch_acc, global_step=self.val_epoch_step)
        self.val_epoch_step += 1
# Example #2 (score: 0) — sample separator from the scraped source
def eval_bin_classifier(model, data_loader, args, writer=None):
    """Evaluate a binary classifier on val/test data.

    Each batch supplies positive and negative images; they are concatenated,
    labeled 1/0 respectively, and scored with a 0.5 sigmoid threshold.

    Args:
        model: network returning one logit per image.
        data_loader: yields ``(images, _, _, neg_images)`` tuples.
        args: namespace providing ``device``.
        writer: unused; kept for interface compatibility with callers.

    Returns:
        float: overall accuracy in [0, 1].
    """
    # CLEANUP: removed the no-op ``if step > 0: continue`` guard and the two
    # inert triple-quoted Grad-CAM snippets — they were dead code left over
    # from a disabled visualization path and had no runtime effect.
    model.eval()
    device = args.device
    kbar = pkbar.Kbar(target=len(data_loader), width=25)
    gt = 0    # total examples seen
    pred = 0  # correctly classified examples
    for step, ex in enumerate(data_loader):
        images, _, _, neg_images = ex
        # Positives first, negatives appended — matches label order below.
        labels = torch.cat((
            torch.ones(len(images)), torch.zeros(len(neg_images))
        )).numpy()
        images = torch.cat((images, neg_images)).to(device)
        with torch.no_grad():
            pred += (
                (torch.sigmoid(model(images)).squeeze(-1).cpu().numpy() > 0.5) * 1
                == labels
            ).sum().item()
        gt += len(images)
        kbar.update(step)

    print(f"\nAccuracy: {pred / gt}")
    return pred / gt
def eval_classifier(model, data_loader, args, writer=None, epoch=0):
    """Evaluate model on val/test data."""
    # Multi-label evaluation: collects sigmoid scores and ground-truth
    # emotion vectors, then reports mean Average Precision. Also writes a
    # sample image plus per-emotion Grad-CAM overlays to TensorBoard.
    model.eval()
    #model.enable_all_grads()
    device = args.device
    kbar = pkbar.Kbar(target=len(data_loader), width=25)
    gt = []    # per-batch ground-truth label arrays
    pred = []  # per-batch sigmoid score arrays
    # Grad-CAM over the last ResNet block; CAM needs gradients, hence no
    # torch.no_grad() around the forward pass below.
    cam = GradCAM(
        model=model, target_layer=model.net.layer4[-1],
        use_cuda=True if torch.cuda.is_available() else False
    )
    for step, ex in enumerate(data_loader):
        images, _, emotions, _ = ex
        images = images.to(device)
        pred.append(torch.sigmoid(model(images)).detach().cpu().numpy())
        gt.append(emotions.cpu().numpy())
        kbar.update(step)
        # Log
        # NOTE(review): always visualizes sample index 1 of the batch —
        # assumes every batch has at least 2 images; confirm batch size > 1.
        writer.add_image(
            'image_sample',
            back2color(unnormalize_imagenet_rgb(images[1], device)),
            epoch * len(data_loader) + step
        )
        # One CAM overlay per active emotion label of that sample.
        for emo_id in torch.nonzero(emotions[1]).reshape(-1):
            grayscale_cam = cam(
                input_tensor=images[1:2],
                target_category=emo_id.item()
            )
            #grayscale_cam = grayscale_cam[0]
            '''
            writer.add_image(
                'gray_grad_cam_{}'.format(emo_id.item()),
                torch.from_numpy(np.uint8(255*grayscale_cam)).unsqueeze(0).repeat(3,1,1),
                epoch * len(data_loader) + step
            )
            '''
            # Colorize the CAM (OpenCV emits BGR, convert to RGB).
            heatmap = cv2.cvtColor(
                cv2.applyColorMap(np.uint8(255*grayscale_cam), cv2.COLORMAP_JET),
                cv2.COLOR_BGR2RGB
            )
            heatmap = torch.from_numpy(np.float32(heatmap) / 255).to(device)
            rgb_img = unnormalize_imagenet_rgb(images[1], device)
            # Overlay heatmap on the image and renormalize to [0, 1].
            rgb_cam_vis = heatmap.permute(2, 0, 1).contiguous() + rgb_img
            rgb_cam_vis = rgb_cam_vis / torch.max(rgb_cam_vis).item()
            writer.add_image(
                'image_grad_cam_{}'.format(emo_id.item()),
                back2color(rgb_cam_vis),
                epoch * len(data_loader) + step
            )
    AP = compute_ap(np.concatenate(gt), np.concatenate(pred))

    print(f"\nAccuracy: {np.mean(AP)}")
    print(AP)
    #model.zero_grad()
    #model.disable_all_grads()
    return np.mean(AP)
    def train(self,
              train_data_loader: DataLoader,
              train_num_examples: int,
              val_data_loader: DataLoader,
              val_num_examples: int,
              output_dir: str = './saved_models/',
              num_epochs: int = 2,
              lr: float = 2e-5):
        """Fine-tune ``self.model``, keeping the best checkpoint on disk.

        Runs ``num_epochs`` epochs with AdamW plus a linear warmup schedule,
        evaluates after each epoch, and saves the model to ``output_dir``
        whenever validation accuracy improves.

        :param train_data_loader: training batches.
        :param train_num_examples: number of training examples.
        :param val_data_loader: validation batches.
        :param val_num_examples: number of validation examples.
        :param output_dir: directory for the best checkpoint.
        :param num_epochs: number of training epochs.
        :param lr: AdamW learning rate.
        :return: None. NOTE(review): ``history`` is collected but never
            returned or stored — consider returning it.
        """

        history = defaultdict(list)
        best_accuracy = 0
        # BUG FIX: the original read the undefined global ``model``; the
        # optimizer must act on this instance's model (saved below as
        # ``self.model``).
        optimizer = AdamW(self.model.parameters(), lr=lr)
        total_steps = len(train_data_loader) * num_epochs

        # Linear decay to zero with no warmup steps.
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=total_steps)

        loss_fn = nn.CrossEntropyLoss().to(self.device)

        print("Running " + str(num_epochs) + " training epochs\n")
        for epoch in range(num_epochs):
            print('Epoch: %d/%d' % (epoch + 1, num_epochs))
            kbar = pkbar.Kbar(target=len(train_data_loader), width=35)

            train_acc, train_loss = self.train_epoch(train_data_loader,
                                                     loss_fn, optimizer,
                                                     scheduler,
                                                     train_num_examples, kbar)

            val_acc, val_loss = self.eval_model(val_data_loader, loss_fn,
                                                val_num_examples)

            history['train_acc'].append(train_acc)
            history['train_loss'].append(train_loss)
            history['val_acc'].append(val_acc)
            history['val_loss'].append(val_loss)

            # Keep only the best-performing checkpoint.
            if val_acc > best_accuracy:
                logger.info("")
                os.makedirs(output_dir, exist_ok=True)
                torch.save(self.model,
                           os.path.join(output_dir, 'pytorch_model.bin'))
                best_accuracy = val_acc

            kbar.add(1,
                     values=[("loss", train_loss), ("acc. ", train_acc),
                             ("val_loss", val_loss), ("val_acc. ", val_acc)])
def train(model, train_loader, criterion, optimizer, epoch, train_batch_num,
          writer):
    """Train a seq2seq model for one epoch.

    Computes a length-masked token-level loss, clips gradients, tracks
    per-batch perplexity, and logs epoch averages to TensorBoard.

    Returns the average masked loss over the epoch's batches.
    """
    train_start = time.time()
    model.train()

    epoch_perplexity = 0
    epoch_loss = 0

    kbar = pkbar.Kbar(train_batch_num)

    # NOTE(review): set_detect_anomaly(True) on every batch adds large
    # overhead — confirm it should stay enabled outside of debugging.
    for batch, (padded_input, padded_target, padded_decoder, input_lens,
                target_lens) in enumerate(train_loader):
        with torch.autograd.set_detect_anomaly(True):
            optimizer.zero_grad()
            batch_size = len(input_lens)
            vocab_size = model.vocab_size
            max_len = max(target_lens)

            padded_input = padded_input.to(DEVICE)
            padded_target = padded_target.type(torch.LongTensor).to(DEVICE)
            padded_decoder = padded_decoder.type(torch.LongTensor).to(DEVICE)

            predictions = model(padded_input, input_lens, epoch,
                                padded_decoder)

            # Padding mask: True at position t of row b iff t < target_lens[b].
            mask = torch.arange(max_len).unsqueeze(0) < torch.tensor(
                target_lens).unsqueeze(1)
            mask = mask.type(torch.float64)
            # NOTE(review): requires_grad on the mask has no training effect
            # (the mask is data, not a parameter) — confirm it can be dropped.
            mask.requires_grad = True
            mask = mask.reshape(batch_size * max_len).to(DEVICE)

            # Flatten to (batch*time, vocab) and (batch*time,) for the
            # token-level criterion.
            predictions = predictions.reshape(batch_size * max_len,
                                              vocab_size).contiguous()
            padded_target = padded_target.reshape(batch_size *
                                                  max_len).contiguous()

            loss = criterion(predictions, padded_target)
            # Average loss over non-padding tokens only.
            masked_loss = torch.sum(loss * mask)
            batch_loss = masked_loss / torch.sum(mask).item()
            batch_loss.backward()
            epoch_loss += batch_loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()
            # Perplexity is exp of the per-token loss.
            perplexity = np.exp(batch_loss.item())
            epoch_perplexity += perplexity
            kbar.update(batch, values=[("loss", batch_loss)])

    kbar.add(1)
    writer.add_scalar('Loss/Train', epoch_loss / train_batch_num, epoch)
    writer.add_scalar('Perplexity/Train', epoch_perplexity / train_batch_num,
                      epoch)

    return epoch_loss / train_batch_num
    def train(self, train_dataset, valid_dataset, epochs=1):
        """Train ``self.model`` with early stopping and LR scheduling.

        :param train_dataset: iterable of dict samples with 'inputs',
            'outputs', and 'pos' tensors.
        :param valid_dataset: validation batches for ``self.evaluate``.
        :param epochs: number of training epochs.
        :return: average per-epoch training loss over ``epochs``.
        """
        es = EarlyStopping(patience=5)
        # Halve-style LR reduction when validation loss plateaus.
        scheduler = ReduceLROnPlateau(self.optimizer,
                                      'min',
                                      patience=2,
                                      verbose=True)
        train_loss, best_val_loss = 0.0, float(1e4)
        epoch, step = 0, 0
        for epoch in range(1, epochs + 1):
            epoch_loss = 0.0
            print(f'Epoch: {epoch}/{epochs}')
            bar = pkbar.Kbar(target=len(train_dataset))
            for step, sample in enumerate(train_dataset):
                self.model.train()
                inputs, labels, pos = sample['inputs'].to(
                    self._device), sample['outputs'].to(
                        self._device), sample['pos'].to(self._device)
                # Padding mask: positions with token id 0 are ignored.
                mask = (inputs != 0).to(self._device, dtype=torch.uint8)
                self.optimizer.zero_grad()
                # Pass the inputs directly, log_probabilities already calls forward
                sample_loss = -self.model.log_probs(inputs, labels, mask, pos)
                sample_loss.backward()
                clip_grad_norm_(self.model.parameters(),
                                5.0)  # Gradient Clipping
                self.optimizer.step()
                epoch_loss += sample_loss.tolist()
                bar.update(step, values=[("loss", sample_loss.item())])
            avg_epoch_loss = epoch_loss / len(train_dataset)
            train_loss += avg_epoch_loss
            valid_loss = self.evaluate(valid_dataset)
            # NOTE(review): reports the cumulative ``train_loss`` rather than
            # this epoch's ``avg_epoch_loss`` — confirm which is intended.
            bar.add(1, values=[("loss", train_loss), ("val_loss", valid_loss)])
            if self.writer:
                self.writer.set_step(epoch, 'train')
                self.writer.add_scalar('loss', epoch_loss)
                self.writer.set_step(epoch, 'valid')
                self.writer.add_scalar('val_loss', valid_loss)

            # Checkpoint only when validation improves (ties included).
            is_best = valid_loss <= best_val_loss
            if is_best:
                logging.info("Model Checkpoint saved")
                best_val_loss = valid_loss
                model_dir = os.path.join(os.getcwd(), 'model',
                                         f'{self.model.name}_ckpt_best')
                self.model.save_checkpoint(model_dir)
            scheduler.step(valid_loss)
            if es.step(valid_loss):
                print(f"Early Stopping activated on epoch #: {epoch}")
                break

        avg_epoch_loss = train_loss / epochs
        return avg_epoch_loss
# Example #7 (score: 0) — sample separator from the scraped source
    def Train(self):
        """Train ``self.model`` for ``self.epochs`` epochs and save weights.

        Supports both a standard cross-entropy model and a Bayesian model
        trained via ``sample_elbo``; writes the final ``state_dict`` to
        ``self.train_weight_path``.
        """

        num_of_batches_per_epoch = int(self.train_len / self.tr_b_sz) + 1
        # NOTE(review): train_loss accumulates across ALL epochs and is never
        # reset or reported — confirm whether it is still needed.
        train_loss = 0
        tgt = 10  # NOTE(review): appears unused in this method.
        criterion = nn.CrossEntropyLoss()
        # optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        optimizer = torch.optim.Adam(self.model.parameters())
        for e in range(self.epochs):
            correct = 0
            total = 0
            kbar = pkbar.Kbar(target=num_of_batches_per_epoch,
                              stateful_metrics=['Loss', 'Accuracy'],
                              width=30)
            self.model.train()
            for batch_idx, (X, Y) in enumerate(self.train_loader):

                X, Y = X.to(DEVICE), Y.to(DEVICE)
                optimizer.zero_grad()
                if self.is_bayesian:
                    # ELBO loss averaged over several sampled weight draws.
                    loss = self.model.sample_elbo(inputs=X,
                                                  labels=Y,
                                                  criterion=criterion,
                                                  sample_nbr=3,
                                                  complexity_cost_weight=1 /
                                                  50000)
                else:
                    outputs = self.model(X)
                    loss = criterion(outputs, Y)

                loss.backward()
                # parameter update
                optimizer.step()
                train_loss += loss.item()
                # Bayesian path needs an extra forward pass to get class
                # scores for the accuracy computation below.
                if self.is_bayesian:
                    outputs = self.model(X)

                _, predicted = outputs.max(1)
                total += Y.size(0)
                correct += predicted.eq(Y).sum().item()
                # if (e+1) % 10 == 0:
                kbar.update(batch_idx + 1,
                            values=[("Epoch", e + 1), ("Loss", loss.item()),
                                    ("Train Accuracy", 100. * correct / total)
                                    ])
            if (e + 1) % 10 == 0:
                print('', end=" ")

        torch.save(self.model.state_dict(), self.train_weight_path)
        print('Trained Weights are Written to {} file'.format(
            self.train_weight_path))
    def train(self):
        """Train ``self.classifier``, checkpointing and validating per epoch.

        Resumes a previous run, trains the remaining epochs, logs step and
        epoch metrics via pkbar and TensorBoard, saves a state-dict
        checkpoint after every epoch, appends the checkpoint path to the log
        file, then runs a validation pass.
        """
        # wandb.watch(self.classifier)
        self.resume_training()
        train_per_epoch = len(self.train_loader)
        for epoch in range(self.start_epoch, self.epochs):
            # BUG FIX: removed a duplicated "Train" banner print and a
            # premature torch.save() of the full model to the same path that
            # is overwritten with a state_dict at the end of the epoch — the
            # early save was dead work with inconsistent content.
            print('Train', end='\t')
            pbar = pkbar.Kbar(target=train_per_epoch, epoch=epoch, num_epochs=self.epochs)
            total_loss = 0
            total_correct = 0
            total = 0
            for i, data in enumerate(self.train_loader):
                x, label = data
                # In-place squeeze of the singleton dim added by the tokenizer.
                ids, atten_mask = x[0].squeeze_(1).to('cuda'), x[1].squeeze_(1).to('cuda')
                label = label.to('cuda')
                self.optim.zero_grad()
                pred = self.classifier(ids, atten_mask)
                pred_classes = torch.argmax(pred,dim=1)
                loss = self.criterion(pred, label)
                loss.backward()
                self.optim.step()
                correct = (label == pred_classes).sum()
                total_loss += loss.item()
                total_correct += correct
                total += label.shape[0]
                pbar.update(i, values=[('step loss', loss), ('step accuray', correct/label.shape[0])])
            epoch_loss = total_loss / len(self.train_loader)
            epoch_acc = (total_correct/total)*100
            pbar.add(1, values=[('epoch loss', epoch_loss), ('epoch accuray', epoch_acc)])
            self.writer.add_scalar('train/loss', epoch_loss, global_step=epoch)
            self.writer.add_scalar('train/accuracy', epoch_acc, global_step=epoch)
            ckpoint = 'checkpoints/bert_model_classification_epoch_{}.pth'.format(epoch)
            torch.save(self.classifier.state_dict(), ckpoint)
            print(f'checkpoints saved at {ckpoint}')
            # BUG FIX: the log file was reopened every epoch via
            # self.logwriter and never closed (handle leak); a context
            # manager flushes and releases it deterministically.
            # NOTE(review): drops the self.logwriter attribute — confirm no
            # other method relied on it staying open.
            with open(self.log_file, 'a') as logwriter:
                logwriter.writelines(f'\n{ckpoint}')
            self.val_epoch()
            self.start_epoch += 1
# Example #9 (score: 0) — sample separator from the scraped source
 def train(self):
     """Run the full training loop.

     Per epoch: reset model metrics, show a pkbar bar, run every training
     batch through ``self.train_step``, then append validation loss and
     metrics to the bar.
     """
     for epoch in range(self.epochs):
         # Reset metrics so the bar starts fresh each epoch.
         _metrics = self.model.get_metrics(reset=True)
         kbar = pkbar.Kbar(
             target=len(self.training_dataloader),
             epoch=epoch, num_epochs=self.epochs,
             stateful_metrics=list(_metrics))
         self.model.train()
         for i, batch in enumerate(self.training_dataloader):
             _loss = self.train_step(batch)
             # NOTE(review): reset=True every step means each bar update
             # shows per-step (not running) metrics — confirm intended.
             _metrics = self.model.get_metrics(reset=True)
             kbar.update(i, values=[("loss", _loss)]+[(k, v)
                                                      for k, v in _metrics.items()])
         val_loss, val_metrics = self.evaluate()
         kbar.add(1, values=[("val_loss", val_loss)] +
                  [(f"val_{k}", v) for k, v in val_metrics.items()])
# Example #10 (score: 0) — sample separator from the scraped source
    def train(self, n_epochs, learning_rate):
        """Train the model with SGD for ``n_epochs`` epochs.

        :param n_epochs: number of epochs to run.
        :param learning_rate: SGD learning rate.
        :return: list with the summed loss of each epoch.
        """
        dl_len = len(self.data_loader)

        self.model.to(self.device)
        # Optimize only trainable parameters (frozen layers are skipped).
        params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(
            params, lr=learning_rate, momentum=0.9, weight_decay=0.0005
        )

        losses_per_ep = []

        for epoch in range(n_epochs):
            self.model.train()

            ep_loss = 0
            kbar = pkbar.Kbar(
                target=dl_len,
                epoch=epoch,
                num_epochs=n_epochs,
                width=20,
                always_stateful=True,
            )

            for i, (images, annotations) in enumerate(self.data_loader):
                images = list(image.to(self.device) for image in images)
                annotations = [
                    {
                        k: v.to(self.device)
                        for k, v in t.items()
                    } for t in annotations
                ]

                # NOTE(review): only the first image/annotation pair of each
                # batch is used for the forward pass — confirm intended.
                # The model returns a dict of named losses; sum them.
                losses = self.model([images[0]], [annotations[0]])
                loss = sum(loss for loss in losses.values())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                ep_loss += loss.item()

                # NOTE(review): the running epoch total is reported as
                # "loss", not the per-step loss — confirm.
                kbar.update(i, values=[("loss", ep_loss)])

            losses_per_ep.append(ep_loss)
            kbar.add(1)

        return losses_per_ep
# Example #11 (score: 0) — sample separator from the scraped source
def validate(combined_model, unsupervised_val, retrievable_items):
    """Validate a joint text/image embedding model.

    Runs the unsupervised validation set through ``combined_model``,
    accumulates the unsupervised loss, then computes recall for each
    candidate-pool size in ``retrievable_items``.

    Returns:
        tuple: (mean epoch loss, list of (recall, item_count) pairs).
    """
    batches_per_epoch = len(unsupervised_val)
    kbar = pkbar.Kbar(target=batches_per_epoch, width=8)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    combined_model.eval()

    image_embeddings = []
    text_embeddings = []

    running_loss = 0.0
    print("Start validation")

    for indx, unsupervised_inputs in enumerate(unsupervised_val):
        unsupervised_image_inputs = unsupervised_inputs[0].to(device)
        unsupervised_text_inputs = unsupervised_inputs[1].to(device)

        with torch.set_grad_enabled(False):
            text_embeddings_unsupervised, image_embeddings_unsupervised = combined_model(
                unsupervised_text_inputs, unsupervised_image_inputs)

            # NOTE(review): criterion_unsupervised is a module-level global,
            # not a parameter — confirm it is defined where this runs.
            unsupervised_loss = criterion_unsupervised(
                text_embeddings_unsupervised, image_embeddings_unsupervised)

            # Keep detached copies for the recall computation below.
            image_embeddings.append(
                image_embeddings_unsupervised.detach().clone())
            text_embeddings.append(
                text_embeddings_unsupervised.detach().clone())

        # statistics
        # NOTE(review): accumulates the loss tensor itself (not .item()), so
        # epoch_loss below is a 0-d tensor — confirm callers accept that.
        running_loss += unsupervised_loss
        kbar.update(indx, values=[("unsupervised_loss", unsupervised_loss)])

    recalls = []

    for item_count in retrievable_items:
        recalls.append((evaluate(text_embeddings, image_embeddings, 5,
                                 item_count), item_count))

    epoch_loss = running_loss / len(unsupervised_val)

    return epoch_loss, recalls
# Example #12 (score: 0) — sample separator from the scraped source
    def Test(self):
        """Evaluate the model on the test set.

        Loads trained weights from ``self.train_weight_path``, runs the test
        loader under ``torch.no_grad()``, and reports running accuracy on a
        pkbar progress bar.
        """
        n_batches = int(self.test_len / self.tst_b_sz) + 1
        self.model.load_state_dict(torch.load(self.train_weight_path))
        kbar = pkbar.Kbar(target=n_batches,
                          stateful_metrics=['Loss', 'Accuracy'],
                          width=11)
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for step, (inputs, targets) in enumerate(self.test_loader, start=1):
                inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
                # Predicted class = argmax over the score dimension.
                predictions = self.model(inputs).max(1)[1]
                total += targets.size(0)
                correct += predictions.eq(targets).sum().item()
                kbar.update(step,
                            values=[("Accuracy", 100. * correct / total)])
def eval_generator(model, data_loader, args):
    """Evaluate model on val/test data."""
    # Accuracy = fraction of pairs where the positive image receives a lower
    # energy score than its negative counterpart.
    model.eval()
    model.disable_batchnorm()
    device = args.device
    kbar = pkbar.Kbar(target=len(data_loader), width=25)
    gt, pred = 0, 0
    for step, batch in enumerate(data_loader):
        pos_images, _, _, neg_images = batch
        # Compute energy for positives and negatives.
        pos_energy = model(pos_images.to(device))
        neg_energy = model(neg_images.to(device))
        gt += len(pos_images)
        pred += (pos_energy < neg_energy).sum()
        kbar.update(step, [("acc", pred / gt)])

    print(f"\nAccuracy: {pred / gt}")
    return pred / gt
# Example #14 (score: 0) — sample separator from the scraped source
    def train(self, num_frames: int, plotting_interval: int = 200):
        """Train the agent."""
        # Interacts with the environment for ``num_frames`` steps, updating
        # the actor/critic once the replay buffer holds a full batch and
        # re-plotting scores/losses at every episode end.
        # NOTE(review): ``plotting_interval`` is accepted but unused here.
        self.is_test = False

        state = self.env.reset()
        actor_losses = []
        critic_losses = []
        scores = []
        score = 0  # cumulative reward of the current episode

        print("Training...")
        kbar = pkbar.Kbar(target=num_frames, width=20)

        for self.total_step in range(1, num_frames + 1):
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

            # if episode ends
            if done:
                state = self.env.reset()
                scores.append(score)
                score = 0

                self._plot(
                    self.total_step,
                    scores,
                    actor_losses,
                    critic_losses,
                )

            # if training is ready
            if (len(self.memory) >= self.batch_size):  # and
                actor_loss, critic_loss = self.update_model()
                actor_losses.append(actor_loss)
                critic_losses.append(critic_loss)

            kbar.add(1)

        self.env.close()
# Example #15 (score: 0) — sample separator from the scraped source
def eval_transformations(model, data_loader, args):
    """Evaluate model on val/test data."""
    # Accuracy = fraction of images whose clean version receives a lower
    # energy score than a randomly augmented version of itself.
    model.eval()
    model.disable_batchnorm()
    device = args.device
    kbar = pkbar.Kbar(target=len(data_loader), width=25)
    gt, pred = 0, 0
    for step, batch in enumerate(data_loader):
        images, _, _, _ = batch
        # Energy of the clean (positive) images.
        pos_energy = model(normalize_imagenet_rgb(images.to(device)))
        # Energy of randomly augmented (negative) copies.
        augmented = rand_augment(images.clone().to(device))
        neg_energy = model(normalize_imagenet_rgb(augmented.to(device)))
        gt += len(images)
        pred += (pos_energy < neg_energy).sum()
        kbar.update(step, [("acc", pred / gt)])

    print(f"\nAccuracy: {pred / gt}")
    return pred / gt
# Example #16 (score: 0) — sample separator from the scraped source
    def populate(self, eps: int = 100) -> None:
        """
        Carries out several random episodes through the environment to
        initially fill up the replay buffer with experiences

        Args:
            eps: number of random episodes to populate the buffer with
        """

        if not self.is_test:
            print("Populate Replay Buffer... ")
            kbar = pkbar.Kbar(target=eps, width=20)
            state = self.env.reset()

            for i in range(eps):
                # One full episode of noise-perturbed random actions.
                while True:
                    # Get action from sample space
                    selected_action = self.env.action_space.sample()
                    # selected_action = 0
                    noise = self.noise.sample()
                    # Keep the perturbed action inside the valid [-1, 1] range.
                    selected_action = np.clip(selected_action + noise, -1.0,
                                              1.0)

                    next_state, reward, done, _ = self.env.step(
                        selected_action)
                    self.transition = [
                        state, selected_action, reward, next_state,
                        int(done)
                    ]
                    self.memory.append(Experience(*self.transition))

                    state = next_state
                    if done:
                        state = self.env.reset()
                        break

                kbar.add(1)
# Example #17 (score: 0) — sample separator from the scraped source
# Optimizer for the pruner network (the main-network optimizer is created
# elsewhere in this script).
optimizer_pruner = optim.Adam(pruner_net.parameters(), lr=args['lr_pruner'])

# Clear plot images left over from a previous run.
for image in glob.glob('run_details/*.png'):
    os.remove(image)

# Three sequential training phases; ``checkpoints`` holds the last epoch
# index of each phase.
total_epochs = args['num_pretrain'] + args['num_first_stage'] + args[
    'num_second_stage']
checkpoints = set(
    np.cumsum([
        args['num_pretrain'], args['num_first_stage'], args['num_second_stage']
    ]) - 1)

for epoch in range(total_epochs):  # loop over the dataset multiple times
    kbar = pkbar.Kbar(target=len(trainloader),
                      epoch=epoch,
                      num_epochs=total_epochs,
                      width=12,
                      always_stateful=False)

    running_loss_main = 0.0
    running_loss_pruner = 0.0
    loss_pruner = 0.0

    # NOTE(review): in Python 3 ``raise (ValueError, msg)`` raises a
    # TypeError (a tuple is not an exception); ``raise ValueError(msg)``
    # is almost certainly what was intended.
    if current_mode not in {'pretrain', 'weight', 'mask', 'both'}:
        raise (ValueError,
               'current_mode must be mask, weight, pretrain or both')

    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)
# Example #18 (score: 0) — sample separator from the scraped source
## PRELOAD MODELS IF NOT THE FIRST EPOCH
# (``epoch``, ``generator``, ``discriminator``, ``lr``, ``b1``, ``b2``,
# ``Tensor`` and ``data_loader`` come from earlier in the script — this is
# a resumable-training fragment.)
if epoch != 0:
    generator.load_state_dict(torch.load("../input/celebhqmodels/celebaHQ_generator.pth"))
    discriminator.load_state_dict(torch.load("../input/celebhqmodels/celebaHQ_discriminator.pth"))

## LOSS FUNCTIONS AND OPT
mse_loss = torch.nn.MSELoss()
l1_loss = torch.nn.L1Loss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr, betas=(b1, b2))

# Resume the epoch counter from where the previous run stopped.
for epoch in range(epoch, n_epochs):

    ## SET UP PROGRESS BAR
    kbar = pkbar.Kbar(target=len(data_loader), epoch=epoch, num_epochs=n_epochs, width=8, always_stateful=False)
    # NOTE(review): this bare assignment has no effect — stateful metric
    # names must be passed as pkbar.Kbar(..., stateful_metrics=[...]).
    stateful_metrics=["G LOSS", "D LOSS"]

    for i, imgs in enumerate(data_loader):

        ## LOAD IMAGES
        img_hr, img_lr = imgs

        ## CONFIGURE THE MODEL OUTPUTS
        # Variable() is a legacy no-op wrapper in modern PyTorch.
        imgs_lr = Variable(img_lr.type(Tensor))
        imgs_hr = Variable(img_hr.type(Tensor))

        ## CONFIGURE GROUND TRUTHS
        # GAN targets: 1 for real, 0 for fake; labels need no gradients.
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), 1))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), 1))), requires_grad=False)
# Example #19 (score: 0) — sample separator from the scraped source
    def run(self):

        # self.optimizer = optim.Adagrad(self.model.parameters(), lr=1e-3, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)
        # optim.SGD(self.model.parameters(), lr=1e-3, momentum=0, dampening=0, weight_decay=0, nesterov=False)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=1e-3,
                                    betas=(0.9, 0.999),
                                    eps=1e-08,
                                    weight_decay=0,
                                    amsgrad=False)

        self.global_step = 0
        self.no_epoch = self.cfg['training']['epoch']

        for epoch in range(self.no_epoch):

            self.epoch = epoch
            tbar = tqdm(self.train_queue)
            print('Epoch: %d/%d' % (self.epoch + 1, self.no_epoch))
            self.logger.info('=> Epoch {}'.format(self.epoch))
            # train and search the model
            self.model.train()
            self.kbar = pkbar.Kbar(target=25, width=0)
            self.epoch_loss = 0
            self.trainloss = 0
            self.trainpsnr = 0

            for i, batch in enumerate(tbar):
                # print(batch)
                self.optimizer.zero_grad()

                noisy_imgs = batch[0]
                clean_image = batch[1]

                noisy_imgs = noisy_imgs.to(
                    self.device)  #, dtype=torch.float32)
                # mask_type = torch.float32 if net.n_classes == 1 else torch.long
                clean_image = clean_image.to(
                    self.device)  #, dtype=torch.float32)
                # print('imgs', imgs)
                clean_pred = self.model(noisy_imgs)
                # print('clean_pred', clean_pred)
                clean_pred = self.norm(clean_pred)
                # clean_pred = noisy_imgs - noise_pred
                # true_masks = noisy_imgs-noise   #clean_imgs

                self.loss = self.criterion(clean_pred, clean_image)
                psnr = self.cal_psnr(clean_pred, clean_image, 1.)

                # self.loss, psnr= self.criterion(noise_pred, noise, clean_pred, true_masks)
                # self.loss = self.loss.to(device=self.device)
                self.trainloss += self.loss
                self.trainpsnr += psnr
                print('Loss: ', self.loss.item())
                print('trainpsnr', psnr)
                self.epoch_loss += self.loss.item()
                self.writer.add_scalar('Loss/train', self.loss.item(),
                                       self.global_step)

                # pbar.set_postfix(**{'loss (batch)': loss.item()})

                if self.cfg['training']['grad_clip']:
                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             self.cfg['training']['grad_clip'])

                # self.optimizer.zero_grad()
                self.loss.backward()
                self.optimizer.step()

                self.kbar.update(i,
                                 values=[("loss",
                                          self.loss.detach().cpu().numpy())])

            self.trainloss = self.trainloss / (len(self.train_queue))
            self.trainpsnr = self.trainpsnr / (len(self.train_queue))

            self.logger.info('TrainLoss : {}'.format(self.trainloss))
            self.logger.info('Trainpsnr : {}'.format(self.trainpsnr))

            # print(torch.sum(list(self.model.parameters())[0]))
            # for param in self.model.parameters():
            #     print(param.data)

            # valid the model
            print('Starting validation....')

            self.model.eval()
            self.tot = 0
            self.tot_psnr = 0

            self.global_step += 1

            tbar_val = tqdm(self.valid_queue)

            for i, batch in enumerate(tbar_val):

                imgs = batch[0]
                clean_image_val = batch[1]

                imgs = imgs.to(self.device)  #, dtype=torch.float32)
                # mask_type = torch.float32 if net.n_classes == 1 else torch.long
                clean_image_val = clean_image_val.to(
                    self.device)  #, dtype=torch.float32)

                clean_pred_val = self.model(imgs)
                clean_pred_val = self.norm(clean_pred_val)
                # clean_pred1 = imgs-noise_pred_val
                # clean_img1 = imgs- noise
                self.val_loss = self.criterion(clean_pred_val, clean_image_val)
                self.psnr_val = self.cal_psnr(clean_pred_val, clean_image_val,
                                              1.)
                # self.val_loss, self.psnr_val = self.criterion(noise_pred_val, noise, clean_pred1, clean_img1)
                self.tot += self.val_loss  #.to(device=self.device)
                self.tot_psnr += self.psnr_val

            self.val_score = self.tot / (len(self.valid_queue))
            self.psnr_score = self.tot_psnr / (len(self.valid_queue))
            self.logger.info('ValidLoss : {}'.format(self.val_score))
            self.logger.info('Validpsnr : {}'.format(self.psnr_score))

            # self.logging.info('Validation Dice Coeff: {}'.format(self.val_score))
            # self.writer.add_scalar('Dice/test', self.val_score, self.global_step)

            self.writer.add_images('images', noisy_imgs,
                                   self.global_step)  #noisy
            self.writer.add_images('masks/true', clean_image,
                                   self.global_step)  #clean
            self.writer.add_images(
                'masks/pred', clean_pred,
                self.global_step)  #pred should be close to clean

            # self.kbar.add(i, values=[("val_score", self.val_score)])
            # self.kbar.add(i, values=[("PSNR_val", self.psnr_score)])
            print("val_score", self.val_score)
            print("PSNR_val", self.psnr_score)

            self.filename = 'ckpt_{0}'.format(self.epoch + 1) + '.pth.tar'
            torch.save(
                {
                    'epoch': self.epoch + 1,
                    'state_dict': self.model.state_dict()
                }, self.filename)
            shutil.move(self.filename, self.ckpt_path)

        # export scalar data to JSON for external processing
        self.writer.close()
        self.logger.info('log dir in : {}'.format(self.save_path))
Exemple #20
0
    def run_epoch(self, mode, model, optim=None, scheduler=None, epoch=None):
        """Run one full pass of `model` over the split selected by `mode`.

        Args:
            mode: 'train', 'valid' or 'test'; picks the data loader and
                decides whether the backward pass / optimizer step run.
            model: spiking network; exposes `multi_model`, `num_models`,
                `time_length`, `loss`, spike statistics, and the custom
                `backward_custom` / `calc_loss` entry points used below.
            optim: optimizer; only used when mode == 'train'.
            scheduler: accepted for signature symmetry; not used here.
            epoch: epoch index, only used for the progress bar header.

        Returns:
            (mean loss, count-decoded accuracy, first-spike-decoded
            accuracy); in train mode the pkbar progress bar is appended
            as a fourth element.
        """
        assert mode in ['train', 'valid', 'test']
        if mode == 'train':
            loader = self.train_loader
            num_data = self.train_num_data

            # With an ensemble ("multi model") the train loader is a list of
            # loaders, one per sub-model; the bar tracks the first one.
            if model.multi_model:
                target_iter = len(loader[0])
            else:
                target_iter = len(loader)
            kbar = pkbar.Kbar(target=target_iter,
                              epoch=epoch,
                              num_epochs=self.config.epoch,
                              width=16,
                              always_stateful=False)
        elif mode == 'valid':
            loader = self.valid_loader
            num_data = self.valid_num_data
        elif mode == 'test':
            loader = self.test_loader
            num_data = self.test_num_data

        progress = 0  # number of samples processed so far
        total_loss = 0
        total_correct = 0
        total_correct_first = 0
        # input.shape = [N, time_length, num_in_features]
        # target.shape = [N, time_length, num_target_features]

        if model.multi_model and not mode == 'test':
            if self.config.evaluation_mode:
                # Evaluate every sub-model on the same (first) stream.
                loader = loader[0]
            else:
                # Draw one batch per sub-model in lockstep.
                loader = zip(*loader)

        for batch_idx, inp_tar in enumerate(loader):
            if model.multi_model:
                if mode == 'test' or self.config.evaluation_mode:
                    # Same batch replicated once per sub-model along dim 0.
                    input = inp_tar[0].unsqueeze(0).repeat(
                        model.num_models,
                        *[1 for i in range(inp_tar[0].dim())])
                    target = inp_tar[1].unsqueeze(0).repeat(
                        model.num_models,
                        *[1 for i in range(inp_tar[1].dim())])
                else:
                    # One distinct batch per sub-model, stacked on dim 0.
                    input = torch.stack([item[0] for item in inp_tar])
                    target = torch.stack([item[1] for item in inp_tar])
            else:
                input = inp_tar[0]
                target = inp_tar[1]

            if self.cuda_enabled:
                input = input.cuda()
                target = target.cuda()

            if self.task == "mnist":
                # Flatten pixels, then convert intensities to latency-coded
                # spike trains before feeding the SNN.
                if model.multi_model:
                    input = input.reshape(input.shape[0] * input.shape[1], -1)
                    input = self.float2spikes(input,
                                              model.time_length,
                                              self.config.max_input_timing,
                                              self.config.min_input_timing,
                                              type='latency',
                                              stochastic=False,
                                              last=False,
                                              skip_zero=True)
                    input = input.reshape(
                        model.num_models,
                        int(input.shape[0] / model.num_models),
                        *input.shape[1:])
                else:
                    input = input.reshape(input.shape[0], -1)
                    input = self.float2spikes(input,
                                              model.time_length,
                                              self.config.max_input_timing,
                                              self.config.min_input_timing,
                                              type='latency',
                                              stochastic=False,
                                              last=False,
                                              skip_zero=True)

            # Run forward pass.
            output = model(input)

            if mode == 'train':
                # Backward and update.
                optim.zero_grad()
                if self.config.target_type == 'latency':
                    model.backward_custom(target)
                else:
                    assert self.config.target_type == 'count'
                    target_spike = self.label2spikes(target.reshape(-1))
                    # model_batch x time x neuron
                    model.backward_custom(target_spike)
                optim.step()
            else:
                # Eval modes only compute the loss; no gradient step.
                if self.config.target_type == 'latency':
                    model.calc_loss(target.reshape(-1))
                else:
                    assert self.config.target_type == 'count'
                    target_spike = self.label2spikes(target.reshape(-1))
                    # model_batch x time x neuron
                    model.calc_loss(target_spike)

            loss = model.loss

            # NOTE(review): assumes the LAST target dim is the batch size
            # (labels shaped [..., batch]) — confirm against the loaders.
            batch_size = target.shape[-1]
            total_loss += loss * batch_size

            # Per-batch spike statistics reported by the model; aggregated
            # across batches as running sum / min / sample-weighted mean.
            num_spike_total = model.num_spike_total
            num_spike_nec = model.num_spike_nec
            first_stime_min = model.first_stime_min
            first_stime_mean = model.first_stime_mean
            if batch_idx == 0:
                self.total_num_spike_total = num_spike_total
                self.total_num_spike_nec = num_spike_nec
                self.min_first_stime_min = first_stime_min
                self.mean_first_stime_mean = first_stime_mean
            else:
                if model.multi_model:
                    # Nested lists (per sub-model, per layer): accumulate
                    # elementwise via numpy, then store back as lists.
                    self.total_num_spike_total = [
                        (np.array(num_spike_total[i]) +
                         np.array(self.total_num_spike_total[i])).tolist()
                        for i in range(len(num_spike_total))
                    ]
                    self.total_num_spike_nec = [
                        (np.array(num_spike_nec[i]) +
                         np.array(self.total_num_spike_nec[i])).tolist()
                        for i in range(len(num_spike_nec))
                    ]
                    self.min_first_stime_min = [
                        min(x, y) for x, y in zip(self.min_first_stime_min,
                                                  first_stime_min)
                    ]
                    # Running mean weighted by samples seen so far.
                    self.mean_first_stime_mean = (
                        (np.array(self.mean_first_stime_mean) * progress +
                         np.array(first_stime_mean) * batch_size) /
                        (progress + batch_size)).tolist()
                else:
                    self.total_num_spike_total = [
                        num_spike_total[i] + self.total_num_spike_total[i]
                        for i in range(len(num_spike_total))
                    ]
                    self.total_num_spike_nec = [
                        num_spike_nec[i] + self.total_num_spike_nec[i]
                        for i in range(len(num_spike_nec))
                    ]
                    self.min_first_stime_min = min(self.min_first_stime_min,
                                                   first_stime_min)
                    # Running mean weighted by samples seen so far.
                    self.mean_first_stime_mean = (
                        self.mean_first_stime_mean * progress +
                        first_stime_mean * batch_size) / (progress +
                                                          batch_size)

            # Decode output spikes two ways: by spike count and by the
            # earliest-spiking output neuron.
            pred_class = self.spikes2label(output, 'count')
            pred_class_first = self.spikes2label(output, 'first')
            if model.multi_model:
                # Keep per-sub-model correct counts (sum over batch dim).
                num_correct = (pred_class.reshape(
                    target.shape) == target).sum(1).float()
                num_correct_first = (pred_class_first.reshape(
                    target.shape) == target).sum(1).float()
                total_correct += num_correct
                total_correct_first += num_correct_first
            else:
                num_correct = (pred_class == target).sum().float()
                num_correct_first = (pred_class_first == target).sum().float()
                total_correct += float(num_correct.item())
                total_correct_first += float(num_correct_first.item())

            current_acc_count = num_correct / batch_size
            current_acc_first = num_correct_first / batch_size

            progress += batch_size
            # if mode == 'train':
            #     self.logger.log_train(model.multi_model, epoch, progress, loss, num_spike_total, num_spike_nec, first_stime_min, first_stime_mean, num_correct, num_correct_first, batch_size, model.term_length, (batch_idx % self.config.log_interval == 0))

            if mode == "train":
                kbar.update(batch_idx + 1,
                            values=[("loss", loss), ("acc", current_acc_count),
                                    ("acc_first", current_acc_first)])

        if mode == "train":
            return (total_loss / progress), (total_correct /
                                             progress), (total_correct_first /
                                                         progress), kbar
        else:
            return (total_loss / progress), (total_correct /
                                             progress), (total_correct_first /
                                                         progress)
    model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
    mse = torch.nn.MSELoss(reduction='none')

    print('=> loading dataset...')
    shifter = torchani.EnergyShifter(None)
    dataset = list(
        torchani.data.load(parser.dataset_path).subtract_self_energies(
            shifter).species_to_indices().shuffle().collate(parser.batch_size))

    print('=> start warming up')
    total_batch_counter = 0
    for epoch in range(0, WARM_UP_BATCHES + 1):

        print('Epoch: %d/inf' % (epoch + 1, ))
        progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)

        for i, properties in enumerate(dataset):

            if not parser.dry_run and total_batch_counter == WARM_UP_BATCHES:
                print('=> warm up finished, start profiling')
                enable_timers(model)
                torch.cuda.cudart().cudaProfilerStart()

            PROFILING_STARTED = not parser.dry_run and (total_batch_counter >=
                                                        WARM_UP_BATCHES)

            if PROFILING_STARTED:
                torch.cuda.nvtx.range_push(
                    "batch{}".format(total_batch_counter))
Exemple #22
0
def train(combined_model, supervised, unsupervised, optimizer, mmd_weight):
    """Train `combined_model` for one epoch on paired + unpaired data.

    Iterates the supervised and unsupervised loaders in lockstep; the
    total loss is the supervised criterion plus the unsupervised (MMD)
    criterion scaled by `mmd_weight`. Embeddings are collected along the
    way for a recall evaluation at the end of the epoch.

    Args:
        combined_model: joint encoder returning
            (text_embeddings, image_embeddings).
        supervised: loader yielding aligned (image, text) batches.
        unsupervised: loader yielding unaligned (image, text) batches.
        mmd_weight: scalar weight applied to the unsupervised loss.

    Returns:
        ((mean total loss, mean supervised loss, mean unsupervised loss),
         supervised recall, unsupervised recall); losses are Python floats
        averaged over len(supervised).
    """
    batches_per_epoch = len(supervised)

    kbar = pkbar.Kbar(target=batches_per_epoch, width=8)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    combined_model.train()

    # Embeddings collected over the epoch for the recall evaluation.
    images_supervised = []
    text_supervised = []

    images_unsupervised = []
    text_unsupervised = []

    running_loss = 0.0
    running_loss_supervised = 0.0
    running_loss_unsupervised = 0.0

    print("Start training")

    for indx, (supervised_inputs,
               unsupervised_inputs) in enumerate(zip(supervised,
                                                     unsupervised)):
        img_inputs = supervised_inputs[0].to(device)
        text_inputs = supervised_inputs[1].to(device)

        unsupervised_image_inputs = unsupervised_inputs[0].to(device)
        unsupervised_text_inputs = unsupervised_inputs[1].to(device)

        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            text_embeddings_supervised, image_embeddings_supervised = combined_model(
                text_inputs, img_inputs)

            # unsupervised output
            text_embeddings_unsupervised, image_embeddings_unsupervised = combined_model(
                unsupervised_text_inputs, unsupervised_image_inputs)

            supervised_loss = criterion_supervised(
                text_embeddings_supervised, image_embeddings_supervised)
            unsupervised_loss = criterion_unsupervised(
                text_embeddings_unsupervised, image_embeddings_unsupervised)

            # Scale the unsupervised (MMD) loss by its weight.
            unsupervised_loss = unsupervised_loss * mmd_weight

            loss = supervised_loss + unsupervised_loss

            loss.backward()
            optimizer.step()

            images_supervised.append(
                image_embeddings_supervised.detach().clone())
            text_supervised.append(text_embeddings_supervised.detach().clone())

            images_unsupervised.append(
                image_embeddings_unsupervised.detach().clone())
            text_unsupervised.append(
                text_embeddings_unsupervised.detach().clone())

        # statistics — accumulate with .item() so the running totals are
        # plain floats; accumulating the loss tensors themselves keeps
        # every per-batch tensor (and its device memory) alive for the
        # whole epoch and makes the returned losses tensors.
        running_loss += loss.item()
        running_loss_supervised += supervised_loss.item()
        running_loss_unsupervised += unsupervised_loss.item()

        kbar.update(indx,
                    values=[("loss", loss),
                            ("supervised_loss", supervised_loss),
                            ("unsupervised_loss", unsupervised_loss)])

    recall_supervised = evaluate(text_supervised, images_supervised)
    recall_unsupervised = evaluate(text_unsupervised, images_unsupervised)

    epoch_loss = (running_loss / len(supervised),
                  running_loss_supervised / len(supervised),
                  running_loss_unsupervised / len(supervised))

    return epoch_loss, recall_supervised, recall_unsupervised
Exemple #23
0
 
 
 # Val loader
 val_loader = DataLoader (val_dataset, batch_size = batch_size, shuffle = False)
 
 
 # Define the size of the progress bar
 train_per_epoch = len (train_df) / batch_size
 
 
 # Train and eval each epoch
 for epoch in range (epochs):
     
     # Create a progress bar
     # @link https://github.com/yueyericardo/pkbar/blob/master/pkbar/pkbar.py (stateful metrics)
     kbar = pkbar.Kbar (target = train_per_epoch, width = 32, stateful_metrics = ['loss', 'acc', 'epoch', 'val_loss', 'val_acc'])
     
     
     # Store our loss and accuracy for plotting
     train_loss_set = []
     
     
     # Store correct predictions
     correct_predictions = 0
     i = 0
     
     
     # Train this epoch
     model.train ()
     
     
def val(model, val_loader, criterion, epoch, val_batch_num, index2letter,
        writer):
    """Run one validation epoch; log loss/perplexity/edit-distance.

    Decodes greedy predictions up to '<eos>', compares them to the padded
    targets with Levenshtein distance, and writes the per-epoch averages
    to TensorBoard. Returns the mean validation loss.
    """
    model.eval()

    sum_distance = 0
    sum_perplexity = 0
    sum_loss = 0

    kbar = pkbar.Kbar(val_batch_num)

    with torch.no_grad():
        for batch, payload in enumerate(val_loader):
            (padded_input, padded_target, padded_decoder, input_lens,
             target_lens) = payload

            num_seqs = len(input_lens)
            vocab = model.vocab_size
            longest = max(target_lens)

            padded_input = padded_input.to(DEVICE)
            padded_target = padded_target.to(DEVICE)
            padded_decoder = padded_decoder.to(DEVICE)

            logits = model(padded_input, input_lens, epoch, padded_decoder)
            pred_ids = torch.argmax(logits, dim=2)
            # Keep a 2-D view of the targets for per-sequence decoding.
            ref_targets = padded_target

            # 1.0 where a position is within the true target length,
            # 0.0 over padding — used to mask the per-token loss.
            length_mask = (torch.arange(longest).unsqueeze(0) <
                           torch.tensor(target_lens).unsqueeze(1))
            length_mask = length_mask.type(torch.float64).reshape(
                num_seqs * longest).to(DEVICE)

            flat_logits = logits.reshape(num_seqs * longest, vocab)
            flat_targets = padded_target.reshape(num_seqs * longest)

            token_loss = criterion(flat_logits, flat_targets)
            step_loss = torch.sum(token_loss * length_mask) / torch.sum(
                length_mask).item()
            sum_loss += step_loss.item()
            sum_perplexity += np.exp(step_loss.item())

            tot_edit = 0
            for i, seq in enumerate(pred_ids):
                # Greedy decode until the first '<eos>' token (inclusive).
                decoded = ''
                for tok in seq:
                    decoded += index2letter[tok.item()]
                    if index2letter[tok.item()] == '<eos>':
                        break
                reference = ''.join(index2letter[tok.item()]
                                    for tok in ref_targets[i])
                # Show one sample pair at the very end of the epoch.
                if i == len(pred_ids) - 1 and batch == val_batch_num - 1:
                    print('\nInput text:\n', reference[:150])
                    print('Pred text:\n', decoded[:150])
                tot_edit += distance(decoded, reference)
            batch_dis = tot_edit / num_seqs

            sum_distance += batch_dis
            kbar.update(batch,
                        values=[("loss", step_loss), ("Dis", batch_dis)])

    kbar.add(1)
    writer.add_scalar('Loss/Val', sum_loss / val_batch_num, epoch)
    writer.add_scalar('Perplexity/Val', sum_perplexity / val_batch_num,
                      epoch)
    writer.add_scalar('Distance/val', sum_distance / val_batch_num, epoch)

    return sum_loss / val_batch_num
Exemple #25
0
def train_transformations(model, data_loaders, args):
    """Train an emotion EBM against augmented/negative images.

    Positive samples are the loader's real images; negatives are (a) the
    loader's negative images and (b) randomly augmented copies refined by
    Langevin dynamics against the current energy. The loss is a
    maximum-likelihood term plus an L2 energy regularizer. A checkpoint
    is written every epoch; if `load_from_ckpnt` reports the model as
    already trained, it is returned immediately.

    Args:
        model: energy network with `disable_batchnorm` support.
        data_loaders: dict with 'train' and 'test' loaders yielding
            (images, _, emotions, neg_images) tuples.
        args: run configuration (device, lr, wd, epochs, langevin
            settings, checkpoint paths, logging fps).

    Returns:
        The trained model.
    """
    device = args.device
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    model, optimizer, _, start_epoch, is_trained = load_from_ckpnt(
        args.classifier_ckpnt, model, optimizer, scheduler=None)
    if is_trained:
        return model
    writer = SummaryWriter('runs/' + args.checkpoint.replace('.pt', ''))

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        print("Epoch: %d/%d" % (epoch + 1, args.epochs))
        kbar = pkbar.Kbar(target=len(data_loaders['train']), width=25)
        model.train()
        model.disable_batchnorm()
        model.zero_grad()
        # model.enable_grads()
        for step, ex in enumerate(data_loaders['train']):
            images, _, emotions, neg_images = ex
            # positive samples
            pos_samples = images.to(device)
            # prepare negative samples via random augmentation
            neg_samples = rand_augment(images.clone().to(device))
            # negative samples refined by Langevin dynamics on the energy;
            # neg_list keeps the intermediate states for logging below
            neg_ld_samples, neg_list = langevin_updates(
                model,
                torch.clone(neg_samples),
                args.langevin_steps,
                args.langevin_step_size,
            )
            # Compute energy for positives and both kinds of negatives
            pos_out = model(normalize_imagenet_rgb(pos_samples))
            neg_img_out = model(normalize_imagenet_rgb(neg_images.to(device)))
            neg_ld_out = model(
                normalize_imagenet_rgb(neg_ld_samples.to(device)))
            # Loss: L2 term keeps energies bounded; ML term pushes positive
            # energy down relative to both negative kinds
            loss_reg = (pos_out**2 + neg_ld_out**2 + neg_img_out**2).mean()
            # loss_reg = (torch.abs(pos_out) + torch.abs(neg_ld_out) + torch.abs(neg_img_out)).mean()
            loss_ml = 2 * pos_out.mean() - neg_ld_out.mean(
            ) - neg_img_out.mean()
            loss = 0.5 * loss_reg + loss_ml
            # if epoch == 0:
            #     loss = loss * 0.05
            '''
            loss = (
                pos_out**2 + neg_out**2 + neg_img_out**2 + neg_img_ld_out**2
                + 3*pos_out - neg_out - neg_img_out - neg_img_ld_out
            ).mean()
             '''
            # Step
            optimizer.zero_grad()
            loss.backward()
            clip_grad(model.parameters(), optimizer)
            optimizer.step()
            kbar.update(step, [("loss", loss)])
            # Log loss
            writer.add_scalar('energy/energy_pos',
                              pos_out.mean().item(),
                              epoch * len(data_loaders['train']) + step)
            writer.add_scalar('energy/energy_neg',
                              neg_ld_out.mean().item(),
                              epoch * len(data_loaders['train']) + step)
            writer.add_scalar('loss/loss_reg', loss_reg.item(),
                              epoch * len(data_loaders['train']) + step)
            writer.add_scalar('loss/loss_ml', loss_ml.item(),
                              epoch * len(data_loaders['train']) + step)
            writer.add_scalar('loss/loss_total', loss.item(),
                              epoch * len(data_loaders['train']) + step)
            # Log image evolution (every 50 steps only)
            if step % 50 != 0:
                continue
            writer.add_image('ld/random_image_sample',
                             back2color(pos_samples[0]),
                             epoch * len(data_loaders['train']) + step)
            writer.add_image('ld/ld_start', back2color(neg_list[0]),
                             epoch * len(data_loaders['train']) + step)
            writer.add_image('ld/ld_end', back2color(neg_list[-1]),
                             epoch * len(data_loaders['train']) + step)
            neg_list = [back2color(neg) for neg in neg_list]
            neg_list = [torch.zeros_like(neg_list[0])] + neg_list
            vid_to_write = torch.stack(neg_list, dim=0).unsqueeze(0)
            writer.add_video('ebm_evolution',
                             vid_to_write,
                             fps=args.ebm_log_fps,
                             global_step=epoch * len(data_loaders['train']) +
                             step)
        writer.add_scalar('lr',
                          optimizer.state_dict()['param_groups'][0]['lr'],
                          epoch)
        # Save checkpoint (both the rolling ckpnt and a per-epoch copy)
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }, args.classifier_ckpnt)
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }, "transformations_%02d.pt" % (epoch + 1))
        print('\nValidation')
        print(eval_transformations(model, data_loaders['test'], args))
    return model
Exemple #26
0
def benchmark(parser, dataset, use_cuda_extension, force_inference=False):
    """Benchmark TorchANI training (or force inference), timing AEV stages.

    Wraps selected `torchani.aev` functions plus the model forward passes
    and `optimizer.step` with timers, runs `parser.num_epochs` epochs over
    `dataset`, then prints a per-epoch timing breakdown.

    Args:
        parser: parsed CLI args (device, num_epochs, ...).
        dataset: iterable of property dicts with 'species', 'coordinates',
            'energies' (and 'forces' when force_inference is True).
        use_cuda_extension: forwarded to AEVComputer to select the CUDA op.
        force_inference: when True, time force computation via autograd and
            skip the backward/optimizer step; otherwise train normally.
    """
    synchronize = True
    timers = {}

    def time_func(key, func):
        # Return a wrapper around `func` that accumulates its wall time
        # (with an optional CUDA sync before stopping) into timers[key].
        timers[key] = 0

        def wrapper(*args, **kwargs):
            start = timeit.default_timer()
            ret = func(*args, **kwargs)
            sync_cuda(synchronize)
            end = timeit.default_timer()
            timers[key] += end - start
            return ret

        return wrapper

    # AEV hyperparameters: radial/angular cutoffs and the eta/shift grids
    # (presumably the standard ANI parameterization — confirm if reused).
    Rcr = 5.2000e+00
    Rca = 3.5000e+00
    EtaR = torch.tensor([1.6000000e+01], device=parser.device)
    ShfR = torch.tensor([
        9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00,
        1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00,
        3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00,
        4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00
    ],
                        device=parser.device)
    Zeta = torch.tensor([3.2000000e+01], device=parser.device)
    ShfZ = torch.tensor([
        1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00,
        1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00
    ],
                        device=parser.device)
    EtaA = torch.tensor([8.0000000e+00], device=parser.device)
    ShfA = torch.tensor(
        [9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00],
        device=parser.device)
    num_species = 4
    aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA,
                                        ShfZ, num_species, use_cuda_extension)

    nn = torchani.ANIModel(build_network())
    model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
    mse = torch.nn.MSELoss(reduction='none')

    # enable timers: monkey-patch each torchani.aev stage (and the model /
    # optimizer entry points) with its timing wrapper
    torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine',
                                           torchani.aev.cutoff_cosine)
    torchani.aev.radial_terms = time_func('torchani.aev.radial_terms',
                                          torchani.aev.radial_terms)
    torchani.aev.angular_terms = time_func('torchani.aev.angular_terms',
                                           torchani.aev.angular_terms)
    torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts',
                                            torchani.aev.compute_shifts)
    torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs',
                                            torchani.aev.neighbor_pairs)
    torchani.aev.neighbor_pairs_nopbc = time_func(
        'torchani.aev.neighbor_pairs_nopbc', torchani.aev.neighbor_pairs_nopbc)
    torchani.aev.triu_index = time_func('torchani.aev.triu_index',
                                        torchani.aev.triu_index)
    torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero',
                                              torchani.aev.cumsum_from_zero)
    torchani.aev.triple_by_molecule = time_func(
        'torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
    torchani.aev.compute_aev = time_func('torchani.aev.compute_aev',
                                         torchani.aev.compute_aev)
    model[0].forward = time_func('total', model[0].forward)  # AEV stage
    model[1].forward = time_func('forward', model[1].forward)  # NN stage
    optimizer.step = time_func('optimizer.step', optimizer.step)

    print('=> start training')
    start = time.time()
    loss_time = 0
    force_time = 0

    for epoch in range(0, parser.num_epochs):

        print('Epoch: %d/%d' % (epoch + 1, parser.num_epochs))
        progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)

        for i, properties in enumerate(dataset):
            species = properties['species'].to(parser.device)
            # coordinates need grad only when forces are differentiated
            coordinates = properties['coordinates'].to(
                parser.device).float().requires_grad_(force_inference)
            true_energies = properties['energies'].to(parser.device).float()
            # count real atoms per molecule (padding uses species < 0)
            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
            _, predicted_energies = model((species, coordinates))
            # TODO add sync after aev is done
            sync_cuda(synchronize)
            energy_loss = (mse(predicted_energies, true_energies) /
                           num_atoms.sqrt()).mean()
            if force_inference:
                sync_cuda(synchronize)
                force_coefficient = 0.1
                true_forces = properties['forces'].to(parser.device).float()
                force_start = time.time()
                try:
                    sync_cuda(synchronize)
                    # forces = -dE/d(coordinates), timed separately
                    forces = -torch.autograd.grad(predicted_energies.sum(),
                                                  coordinates,
                                                  create_graph=True,
                                                  retain_graph=True)[0]
                    sync_cuda(synchronize)
                except Exception as e:
                    alert('Error: {}'.format(e))
                    return
                force_time += time.time() - force_start
                force_loss = (mse(true_forces, forces).sum(dim=(1, 2)) /
                              num_atoms).mean()
                loss = energy_loss + force_coefficient * force_loss
                sync_cuda(synchronize)
            else:
                loss = energy_loss
            rmse = hartree2kcalmol(
                (mse(predicted_energies,
                     true_energies)).mean()).detach().cpu().numpy()
            progbar.update(i, values=[("rmse", rmse)])
            if not force_inference:
                sync_cuda(synchronize)
                loss_start = time.time()
                loss.backward()
                # print('2', coordinates.grad)
                sync_cuda(synchronize)
                loss_stop = time.time()
                loss_time += loss_stop - loss_start
                optimizer.step()
                sync_cuda(synchronize)

        checkgpu()
    sync_cuda(synchronize)
    stop = time.time()

    # Per-epoch averages of the accumulated timers.
    print('=> More detail about benchmark PER EPOCH')
    total_time = (stop - start) / parser.num_epochs
    loss_time = loss_time / parser.num_epochs
    force_time = force_time / parser.num_epochs
    opti_time = timers['optimizer.step'] / parser.num_epochs
    forward_time = timers['forward'] / parser.num_epochs
    aev_time = timers['total'] / parser.num_epochs
    print_timer('   Total AEV', aev_time)
    print_timer('   Forward', forward_time)
    print_timer('   Backward', loss_time)
    print_timer('   Force', force_time)
    print_timer('   Optimizer', opti_time)
    print_timer(
        '   Others', total_time - loss_time - aev_time - forward_time -
        opti_time - force_time)
    print_timer('   Epoch time', total_time)
def vs_net_train(args):
    """Train the cascaded reconstruction network on MAGIC data.

    Args:
        args: parsed CLI namespace providing train_h5/val_h5 (dataset paths),
            epoch, cascade, lr, nb (batch size), Result_name, device (GPU
            index), checkpoint (optional weights file), aug (train-time
            augmentation flag) and zpad (True -> zero-padded inputs,
            False -> LORAKS inputs).

    Side effects: writes validation snapshots under Result_<name>/, weights
    and a pickled loss-plot object under models/<name>/, and appends per-epoch
    losses to Loss_Logging.txt.
    """
    train_path = args.train_h5
    val_path = args.val_h5
    NEPOCHS = args.epoch
    CASCADE = args.cascade
    LR = args.lr
    NBATCH = args.nb
    Res_name = args.Result_name
    device_num = args.device
    chpoint = args.checkpoint
    aug = args.aug
    zpad = args.zpad

    device = 'cuda:' + str(device_num)
    # BUG FIX: the original `if zpad is False ... elif zpad is True` left the
    # datasets undefined (NameError further down) for any non-bool value of
    # zpad; a plain truthiness test covers every case.
    if zpad:
        print("input is from Zero-Padding")
        trainset = D.MAGIC_Dataset_zpad(train_path,
                                        augmentation=aug,
                                        verbosity=False)
        testset = D.MAGIC_Dataset_zpad(val_path,
                                       augmentation=False,
                                       verbosity=False)
    else:
        print("input is from LORAKS")
        trainset = D.MAGIC_Dataset_LORAKS(train_path,
                                          augmentation=aug,
                                          verbosity=False)
        testset = D.MAGIC_Dataset_LORAKS(val_path,
                                         augmentation=False,
                                         verbosity=False)

    trainloader = DataLoader(trainset,
                             batch_size=NBATCH,
                             shuffle=True,
                             pin_memory=True,
                             num_workers=0)
    valloader = DataLoader(testset,
                           batch_size=NBATCH,
                           shuffle=False,
                           pin_memory=True,
                           num_workers=0)

    dataloaders = {'train': trainloader, 'validation': valloader}

    net = network(alfa=None, beta=0.5, cascades=CASCADE)
    net = net.to(device)
    if chpoint is not None:
        print('Loading network from:', chpoint)
        net.load_state_dict(torch.load(chpoint))

    ########## Training ####################
    # Fixed validation sample used for qualitative snapshots every 10 epochs.
    _im0, _true, _Sens, _X_kJVC, _mask = testset[13]
    _im0, _true, _Sens, _X_kJVC, _mask = (
        _im0.unsqueeze(0).to(device), _true.unsqueeze(0).to(device),
        _Sens.unsqueeze(0).to(device), _X_kJVC.unsqueeze(0).to(device),
        _mask.unsqueeze(0).to(device))

    criterion = torch.nn.L1Loss()

    liveloss = PlotLosses()
    optimizer = torch.optim.Adam(net.parameters(), lr=LR)
    for epoch in range(NEPOCHS):
        print('Epoch', epoch + 1)
        logs = {}
        # BUG FIX: iterate a list, not a set literal, so 'train' always runs
        # before 'validation' (set iteration order is arbitrary).
        for phase in ['train', 'validation']:
            if phase == 'train':
                kbar = pkbar.Kbar(target=len(trainloader), width=2)
                net.train()
                prefix = ''
            else:
                net.eval()
                prefix = 'val_'

            running_loss = 0.0

            iii = 0
            for im0, true, tSens, tX_kJVC, tmask in dataloaders[phase]:

                im0, true, tX_kJVC, tSens, tmask = (
                    im0.to(device, non_blocking=True),
                    true.to(device, non_blocking=True),
                    tX_kJVC.to(device, non_blocking=True),
                    tSens.to(device, non_blocking=True),
                    tmask.to(device, non_blocking=True))

                if phase == 'train':
                    out = net(im0, tX_kJVC, tmask, tSens)
                    loss = criterion(out, true)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    running_loss = running_loss + loss.item() * im0.size(0)
                    kbar.update(iii,
                                values=[('L', 100 * running_loss / (iii + 1))])
                    iii = iii + 1
                else:
                    with torch.no_grad():
                        out = net(im0, tX_kJVC, tmask, tSens)
                        loss = criterion(out, true)
                        running_loss = running_loss + loss.item() * im0.size(0)

            # Hoisted out of the batch loop: only the final average matters.
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            logs[prefix + 'Loss'] = epoch_loss * 100

        if epoch % 10 == 0:
            save_name = 'Result_' + Res_name + '/Val_Epoch_' + str(
                epoch) + '.jpg'
            show_output(net, _im0, _true, _X_kJVC, _Sens, _mask, save_name)
            file_name = 'models/' + Res_name + '/Weights_Epoch_' + str(epoch)

            print(' SAVING WEIGHTS : ' + file_name)
            torch.save(net.state_dict(), file_name)

            # Context manager guarantees the pickle file is closed even if
            # pickling raises.  (Saving Lossplot objects to pickle.)
            with open("models/" + Res_name + "/Losses_graph.obj", "wb") as f:
                pickle.dump(liveloss, f)

        liveloss.update(logs)
        kbar.add(1,
                 values=[('Train', logs['Loss']), ('Val', logs['val_Loss'])])
        with open("Loss_Logging.txt", "a") as f:
            f.write("Epoch{} : Training Loss : {:.5f} & Validation Loss: {:.5f}\n".
                    format(epoch, logs['Loss'], logs['val_Loss']))
def ni_to_tensor_CT_folder(foldr, as_patches=False, overwrite=False):
    """Convert NIfTI CT volumes (and matching segmentations) to saved tensors.

    Args:
        foldr: pathlib.Path of a directory holding 'volume*' .nii files and
            sibling 'segmentation*' files (same name with 'volume' replaced).
        as_patches: if True, split each volume/mask into four 256x256 spatial
            quadrants saved under <root>/patches; otherwise save whole tensors
            under foldr.parent/tensors.
        overwrite: re-create output .pt files even when they already exist.
    """
    if as_patches:
        # NOTE(review): parents[4] hard-codes the directory depth — confirm
        # against the expected dataset layout.
        patches_fldr = foldr.parents[4] / "patches"
        # BUG FIX: `if not patches_fldr:` tested the Path object's truthiness
        # (always True), so the directory was never created.
        os.makedirs(patches_fldr, exist_ok=True)
    else:
        tensor_fldr = foldr.parent / "tensors"
        os.makedirs(tensor_fldr, exist_ok=True)

    directory = os.listdir(foldr)
    vols = [f for f in directory if 'volume' in f]
    bar = pkbar.Kbar(target=len(vols), width=28)
    for idx, vol_name in enumerate(vols):
        bar.update(idx)  #,values =[("Processing case:" ,vol_name)])
        mask_name = vol_name.replace('volume', 'segmentation')

        if not as_patches:
            vol_outname = tensor_fldr / vol_name.replace(".nii", ".pt")
            mask_outname = tensor_fldr / mask_name.replace(".nii", ".pt")

            if overwrite or not vol_outname.exists():
                # read array -> normalize -> torch float tensor -> add
                # channel dimension
                img_np = np.array(nib.load(Path(foldr) / vol_name).dataobj)
                vol_norm_pt = torch.from_numpy(
                    normalize_numpy(img_np)).float().unsqueeze(0)
                torch.save(vol_norm_pt, vol_outname)

            if overwrite or not mask_outname.exists():
                mask_np = np.array(nib.load(Path(foldr) / mask_name).dataobj)
                mask_pt = torch.from_numpy(
                    mask_np.astype(float)).unsqueeze(0).float()
                torch.save(mask_pt, mask_outname)
        else:  # save as patches -- very large dataset so this code below is untested
            img_np = np.array(nib.load(Path(foldr) / vol_name).dataobj)
            vol_norm_pt = torch.from_numpy(
                normalize_numpy(img_np)).float().unsqueeze(0)
            # BUG FIX: the second quadrant originally indexed the channel axis
            # with 0 instead of ':', dropping a dimension relative to the
            # other three quadrants; all four now keep the channel axis.
            img = [
                vol_norm_pt[:, 0:256, 0:256, :],
                vol_norm_pt[:, 256:512, 0:256, :],
                vol_norm_pt[:, 0:256, 256:512, :],
                vol_norm_pt[:, 256:512, 256:512, :],
            ]
            mask_np = np.array(nib.load(Path(foldr) / mask_name).dataobj)
            mask_pt = torch.from_numpy(
                mask_np.astype(float)).unsqueeze(0).float()
            mask = [
                mask_pt[:, 0:256, 0:256, :],
                mask_pt[:, 256:512, 0:256, :],
                mask_pt[:, 0:256, 256:512, :],
                mask_pt[:, 256:512, 256:512, :],
            ]

            for i in range(4):
                # 4 quadrants given indexed suffixes and saved separately.
                vol_outname = patches_fldr / str(
                    vol_name.replace(".nii", "_patch" + str(i) + ".pt"))
                if overwrite or not vol_outname.exists():
                    torch.save(img[i], vol_outname)
                mask_outname = patches_fldr / str(
                    mask_name.replace(".nii", "_patch" + str(i) + ".pt"))
                if overwrite or not mask_outname.exists():
                    torch.save(mask[i], mask_outname)
# learning_rate = 0.001 #FirstBatch
learning_rate = 0.0001  #Vcab_Recordings with clutter removal
# learning_rate = 0.005 #Vcab_Recordings
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#%%
# CNN model training
loss_list = []
val_loss_list = []
iteration_count = 0
# Batches per epoch, used to size the progress bar targets.
train_per_epoch = math.ceil(len(train_set) / batch_size)
val_per_epoch = math.ceil(len(val_set) / batch_size)

for epoch in range(num_epochs):
    sum_loss = 0
    sum_val_loss = 0
    # NOTE(review): kbar is created but never updated below, and sum_loss /
    # sum_val_loss / val_per_epoch are never accumulated or used in this
    # chunk — presumably the kbar updates and validation pass follow after
    # this excerpt; confirm against the full script.
    kbar = pkbar.Kbar(target=train_per_epoch, width=8)
    for i, sample in enumerate(train_loader):
        x_train = sample["imagePower"].float().to(device)
        y_train = sample["label"].float().to(device)
        # Reshape to (N, 1, 29, 29, 24) — presumably a single-channel 3-D
        # volume per sample; verify against the model's expected input.
        x_train = Variable(x_train.view(len(x_train), 1, 29, 29, 24))
        y_train = Variable(y_train)
        # Clear gradients
        optimizer.zero_grad()
        # Forward propagation
        outputs = model(x_train)
        # Compute the loss (assumes `error` is the criterion, e.g.
        # cross-entropy — confirm where it is defined)
        loss = error(outputs, y_train)
        # Calculating gradients
        loss.backward()
        # Update parameters
        optimizer.step()
net = model
pdb.set_trace()
net.build(torch.zeros([64, 3, 32, 32], dtype=torch.float).to(device))

trainloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
train_loss = 0
correct = 0
total = 0
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),
                      lr=0.05,
                      momentum=0.9,
                      weight_decay=5e-4)
kbar = pkbar.Kbar(target=len(trainloader),
                  epoch=0,
                  num_epochs=1,
                  width=8,
                  always_stateful=False)

for batch_idx, (inputs, targets) in enumerate(trainloader):
    print(batch_idx, inputs.shape)
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)