Exemplo n.º 1
0
def train_on_batch(model: Tree2Seq, criterion: nn.modules.loss,
                   optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
                   graph: dgl.BatchedDGLGraph, labels: List[str], params: Dict,
                   device: torch.device) -> Dict:
    """Run one optimisation step on a single batch and report its metrics.

    Args:
        model: network being trained; it is called as
            ``model(graph, root_indexes, labels, teacher_force, device)`` and
            must return ``(logits, ground_truth)``.
        criterion: loss applied to the flattened logits and targets.
        optimizer: optimizer stepped once, after gradient-norm clipping.
        scheduler: learning-rate scheduler stepped once per batch.
        graph: batched input graph.
        labels: target labels forwarded to the model.
        params: must provide the keys ``'teacher_force'`` and ``'clip_norm'``.
        device: device the root indexes are moved to; also passed to the model.

    Returns:
        Dict with the scalar ``'loss'`` and the ``'statistics'`` returned by
        ``calculate_batch_statistics``.
    """
    model.train()
    roots = get_root_indexes(graph).to(device)

    # Forward pass.
    model.zero_grad()
    logits, targets = model(graph, roots, labels, params['teacher_force'],
                            device)

    # The first position along the leading axis is excluded from the loss
    # and from the metrics (presumably the start-of-sequence step -- confirm).
    logits, targets = logits[1:], targets[1:]
    n_outputs = logits.shape[-1]
    loss = criterion(logits.view(-1, n_outputs), targets.view(-1))

    # Backward pass with gradient-norm clipping, then parameter and LR update.
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), params['clip_norm'])
    optimizer.step()
    scheduler.step()

    # Metrics; the ids of PAD, UNK and EOS are handed to the statistics helper.
    predicted = model.predict(logits)
    loss_value = loss.item()
    special_ids = [
        model.decoder.label_to_id[token] for token in [PAD, UNK, EOS]
    ]
    statistics = calculate_batch_statistics(targets, predicted, special_ids)
    return {'loss': loss_value, 'statistics': statistics}
Exemplo n.º 2
0
def learning_rate_scheduling(validation: Dict[str, float],
                             scheduler: torch.optim.lr_scheduler) -> None:
    """
    Feeds the summed validation AP to the learning rate scheduler.

    For every entry of ``validation`` the mean of its
    ``'ap/iou=0.50:0.95/area=all/max_dets=100'`` value is taken; the sum of
    these means is passed to ``scheduler.step``. Afterwards the scheduler's
    ``best`` and ``num_bad_epochs`` attributes are printed.

    NOTE(review): the ``Dict[str, float]`` annotation does not match the
    code, which indexes every value by the metric key -- confirm with the
    caller.
    """
    metric_key = 'ap/iou=0.50:0.95/area=all/max_dets=100'
    total = sum(entry[metric_key].mean() for entry in validation.values())
    scheduler.step(total)

    report = "Scheduler: Best metric seen so far %f, number of bad epochs %i"
    print(report % (scheduler.best, scheduler.num_bad_epochs))
Exemplo n.º 3
0
    def fit(self,
            epochs: int,
            train_dl: DataLoader,
            test_dl: DataLoader,
            criterion: torch.nn,
            optimizer: torch.optim,
            scheduler: torch.optim.lr_scheduler = None):
        """Train for ``epochs`` epochs, evaluating on ``test_dl`` after each.

        Every training batch must unpack into ``(input_ids, att_masks,
        labels)``; all three are moved to ``self.device``. The mean training
        loss of the most recent epoch is also stored in
        ``self.last_train_loss``.

        Args:
            epochs: number of passes over ``train_dl``.
            train_dl: loader providing the training batches.
            test_dl: loader handed to ``self.evaluate`` after every epoch.
            criterion: loss called as ``criterion(outputs.squeeze(), labels)``.
            optimizer: optimizer stepped once per batch.
            scheduler: optional scheduler; when given it is stepped once per
                batch, right after the optimizer.

        Returns:
            Tuple ``(train_losses, eval_losses)`` with one entry per epoch.
        """
        per_epoch_train = []
        per_epoch_eval = []

        for epoch_idx in tqdm(range(epochs), desc="Epochs"):
            # --- training pass ---
            self.train()
            running = []
            progress = tqdm(train_dl,
                            total=len(train_dl),
                            desc="- Remaining batches")

            for raw_batch in progress:
                input_ids, att_masks, labels = tuple(
                    tensor.to(self.device) for tensor in raw_batch)

                # Reset gradients, then forward / backward / update.
                optimizer.zero_grad()
                predictions = self(input_ids, att_masks)
                batch_loss = criterion(predictions.squeeze(), labels)
                batch_loss.backward()
                optimizer.step()

                if scheduler is not None:
                    scheduler.step()

                running.append(batch_loss.item())

            epoch_train_loss = np.mean(running)
            self.last_train_loss = epoch_train_loss

            # --- evaluation pass ---
            tqdm.write(f"Epoch: {epoch_idx+1}")
            _unused, epoch_eval_loss = self.evaluate(test_dl, criterion)

            per_epoch_train.append(epoch_train_loss)
            per_epoch_eval.append(epoch_eval_loss)

        return per_epoch_train, per_epoch_eval
Exemplo n.º 4
0
def train_on_batch(
        model: Tree2Seq, criterion: nn.modules.loss, optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
        graph: dgl.DGLGraph, labels: torch.Tensor, clip_norm: int
) -> Dict:
    """Perform a single training step and return the batch information.

    The forward pass is delegated to ``_forward_pass``; the dict it returns is
    extended with the key ``'learning_rate'`` and handed back to the caller.

    Args:
        model: network being trained.
        criterion: loss passed through to ``_forward_pass``.
        optimizer: optimizer stepped once.
        scheduler: scheduler stepped once, after the optimizer.
        graph: input graph for the forward pass.
        labels: target tensor for the forward pass.
        clip_norm: bound handed to ``clip_grad_value_``.
            NOTE(review): despite the name, gradients are clipped by value,
            not by norm -- confirm this is intended.

    Returns:
        The batch-info dict produced by ``_forward_pass`` plus the learning
        rate reported by the scheduler before this step's update.
    """
    model.train()
    model.zero_grad()

    loss, prediction, batch_info = _forward_pass(model, graph, labels, criterion)

    # Read the learning rate before the scheduler advances.
    current_lr = scheduler.get_last_lr()[0]
    batch_info['learning_rate'] = current_lr

    loss.backward()
    nn.utils.clip_grad_value_(model.parameters(), clip_norm)
    optimizer.step()
    scheduler.step()

    # Drop the references to the step's tensors and release cached GPU memory.
    del loss, prediction
    torch.cuda.empty_cache()

    return batch_info
Exemplo n.º 5
0
    def _train(self,
               criterion: typing.Sequence[typing.Callable],
               earlystopping: EarlyStopping,
               scheduler: torch.optim.lr_scheduler,
               optimizer: torch.optim.Optimizer,
               train_loader: torch.utils.data.DataLoader,
               valid_loader: typing.Union[
                   None, torch.utils.data.DataLoader] = None,
               verbose: bool = True):
        """
        Class function to train the Trainer object. After every training epoch, a validation can be performed.
        Per-epoch metrics are collected on the instance, written to ``loss_metrics.json`` and the final
        model weights are saved; nothing is returned.
        :param criterion: sequence with one loss (distance only) or two losses (distance and a cosine
            similarity whose complement ``1 - similarity`` is added with an epoch-dependent weight)
        :param earlystopping: (EarlyStopping) custom class for doing early stopping; falsy to disable
        :param scheduler: (torch.optim.lr_scheduler) stepped once per epoch with the mean training
            euclidean distance; falsy to disable
        :param optimizer: (torch.optim.Optimizer)
        :param train_loader: (torch.utils.data.DataLoader) object for training; batches are indexed with
            the keys 'ecfp' (input) and 'cddd' (target)
        :param valid_loader: (torch.utils.data.DataLoader) for validation after every epoch, or None
            (default) to skip validation and early stopping
        :param verbose (bool): whether or not to print out on console. Default: True
        :raises ValueError: if ``criterion`` does not contain exactly one or two entries
        """
        # Fail early: any other length would leave `loss` undefined below.
        if len(criterion) not in (1, 2):
            raise ValueError(
                'criterion must contain 1 or 2 loss functions, got {}.'.format(
                    len(criterion)))

        print('Training the Neuraldecipher for {} epochs.'.format(
            self.trainparams['n_epochs']))
        train_count = 0
        test_count = 0

        # create modelpath savedirs
        logdir_path = os.path.join('../logs', self.trainparams['output_dir'])
        model_outpath = os.path.join('../models',
                                     self.trainparams['output_dir'])
        os.makedirs(logdir_path, exist_ok=True)
        os.makedirs(model_outpath, exist_ok=True)

        # create summary writer
        writer = SummaryWriter(log_dir=logdir_path)

        # turn model in train mode
        self.model = self.model.train()

        # saving arrays
        self.train_loss_array = []
        self.test_loss_array = []
        self.train_euclidean_array = []
        self.test_euclidean_array = []

        if len(criterion) == 2:
            # Weight of the cosine term per epoch: 0 at epoch 0 and growing
            # exponentially (base 10) towards 1 over the course of training.
            a = 10
            weight_func = lambda x: (a**x - 1) / (a - 1)
            self.cosine_weight_loss = [
                weight_func(f / self.trainparams['n_epochs'])
                for f in range(self.trainparams['n_epochs'])
            ]

        for epoch in range(0, self.trainparams['n_epochs']):

            self.train_loss = 0.0
            self.test_loss = 0.0
            self.train_euclidean = 0.0
            self.test_euclidean = 0.0

            for step, batch in tqdm(enumerate(train_loader),
                                    total=len(train_loader)):

                ecfp_in = batch['ecfp'].to(device=self.device,
                                           dtype=torch.float32)
                cddd_out = batch['cddd'].to(device=self.device,
                                            dtype=torch.float32)

                cddd_predicted = self.model(ecfp_in)

                if len(criterion) == 1:  # only difference, e.g MSE or logcosh
                    loss = criterion[0](cddd_predicted, cddd_out)
                else:  # difference AND cosine loss
                    d_loss = criterion[0](cddd_predicted, cddd_out)
                    cosine_loss = 1 - criterion[1](cddd_predicted, cddd_out)
                    loss = d_loss + self.cosine_weight_loss[epoch] * cosine_loss

                batch_train_euclidean = l2_distance(
                    y_pred=cddd_predicted, y_true=cddd_out).item()
                self.train_euclidean += batch_train_euclidean

                # compute gradients and update weights
                optimizer.zero_grad()
                loss.backward()
                batch_loss = loss.item()
                self.train_loss += batch_loss
                optimizer.step()

                writer.add_scalar(tag='Loss/train',
                                  scalar_value=batch_loss,
                                  global_step=train_count)
                writer.add_scalar(tag='Euclidean/train',
                                  scalar_value=batch_train_euclidean,
                                  global_step=train_count)

                train_count += 1
                if verbose and train_count % 500 == 0:
                    tqdm.write('*' * 100)
                    tqdm.write(
                        'Epoch [%d/%d] Batch [%d/%d] Loss Train: %.4f Mean L2 Distance %.4f'
                        %
                        (epoch, self.trainparams['start_epoch'] +
                         self.trainparams['n_epochs'], step, len(train_loader),
                         batch_loss, batch_train_euclidean))
                    tqdm.write('*' * 100 + '\n')

            # learning rate scheduler at the end of the epoch
            self.train_loss /= len(train_loader)
            self.train_euclidean /= len(train_loader)
            if scheduler:
                scheduler.step(self.train_euclidean)

            # evaluation
            stop_training = False
            if valid_loader:
                if verbose:
                    tqdm.write('Epoch %d finished. Doing validation:' %
                               (epoch))
                writer, test_count = self._eval(criterion, valid_loader,
                                                writer, test_count, epoch)
                self.test_loss /= len(valid_loader)
                self.test_euclidean /= len(valid_loader)

                if verbose:
                    tqdm.write(
                        'Epoch [%d/%d] Loss Train: %.4f Euclidean Train: %.4f Loss Valid: %.4f Euclidean Valid: %.4f'
                        % (epoch, self.trainparams['n_epochs'],
                           self.train_loss, self.train_euclidean,
                           self.test_loss, self.test_euclidean))

                if earlystopping:
                    earlystopping(metric_val=self.test_euclidean,
                                  model=self.model,
                                  modelpath=model_outpath,
                                  epoch=epoch)
                    if earlystopping.early_stop:
                        print(
                            'Early stopping the training the NeuralDecipher Model on ECFP fingerprints  \
                        with radii {} and {} bit length. \n Results and models are saved at {} and {}.'
                            .format(self.trainparams['radii'],
                                    self.model.input_dim, logdir_path,
                                    model_outpath))
                        stop_training = True

            ## array saving
            # Done before leaving the loop so that the epoch which triggered
            # early stopping is part of the saved metrics as well.
            self.train_loss_array.append(self.train_loss)
            self.test_loss_array.append(self.test_loss)
            self.train_euclidean_array.append(self.train_euclidean)
            self.test_euclidean_array.append(self.test_euclidean)

            if stop_training:
                break

        # Flush pending events and release the event file.
        writer.close()

        print(
            'Finished training the NeuralDecipher Model on ECFP fingerprints with radii {} and {} bit length. \n \
              Results and models are saved at {} and {}.'.format(
                self.trainparams['radii'], self.model.input_dim, logdir_path,
                model_outpath))

        ## model saving
        # The file name carries the validation euclidean distance of the last
        # epoch (0.0 when no validation was run).
        torch.save(
            self.model.state_dict(),
            os.path.join(model_outpath,
                         'final_model_{}.pt'.format(self.test_euclidean)))

        ## array saving
        json_array = {
            'train_loss': self.train_loss_array,
            'train_euclidean': self.train_euclidean_array,
            'test_loss': self.test_loss_array,
            'test_euclidean': self.test_euclidean_array
        }

        json_filepath = os.path.join(model_outpath, 'loss_metrics.json')
        with open(json_filepath, 'w') as f:
            json.dump(json_array, f)
def do_epoch(args: argparse.Namespace,
             train_loader: torch.utils.data.DataLoader, model: DDP,
             optimizer: torch.optim.Optimizer,
             scheduler: torch.optim.lr_scheduler, epoch: int,
             callback: VisdomLogger, iter_per_epoch: int,
             log_iter: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Train ``model`` for ``iter_per_epoch`` iterations and log periodically.

    Args:
        args: run configuration; this function reads ``scheduler``,
            ``log_freq``, ``num_classes_tr`` and ``distributed``.
        train_loader: loader yielding ``(images, gt)`` batches.
        model: model to optimise.
        optimizer: optimizer stepped once per iteration.
        scheduler: stepped every iteration when ``args.scheduler == 'cosine'``,
            otherwise once at the end of the epoch.
        epoch: index of the current epoch, used for the logging x-axis.
        callback: optional logger; ``None`` disables plotting.
        iter_per_epoch: number of batches drawn from ``train_loader``.
        log_iter: number of logging slots in the returned tensors.

    Returns:
        ``(train_mIous, train_losses)``: tensors of length ``log_iter`` holding
        the values recorded every ``args.log_freq`` iterations. They are only
        filled on the main process.
    """
    loss_meter = AverageMeter()
    train_losses = torch.zeros(log_iter).to(dist.get_rank())
    train_mIous = torch.zeros(log_iter).to(dist.get_rank())

    iterable_train_loader = iter(train_loader)

    if main_process(args):
        bar = tqdm(range(iter_per_epoch))
    else:
        bar = range(iter_per_epoch)

    for i in bar:
        # Re-enable train mode every iteration: logging below switches to eval.
        model.train()
        current_iter = epoch * len(train_loader) + i + 1

        # Built-in next(): iterators have no .next() method in Python 3.
        images, gt = next(iterable_train_loader)
        images = images.to(dist.get_rank(), non_blocking=True)
        gt = gt.to(dist.get_rank(), non_blocking=True)

        loss = compute_loss(
            args=args,
            model=model,
            images=images,
            targets=gt.long(),
            num_classes=args.num_classes_tr,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.scheduler == 'cosine':
            scheduler.step()

        if i % args.log_freq == 0:
            model.eval()
            # Metrics only: no autograd graph is needed for this forward pass.
            with torch.no_grad():
                logits = model(images)
            intersection, union, target = intersectionAndUnionGPU(
                logits.argmax(1), gt, args.num_classes_tr, 255)
            if args.distributed:
                dist.all_reduce(loss)
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)

            allAcc = (intersection.sum() / (target.sum() + 1e-10))  # scalar
            mAcc = (intersection / (target + 1e-10)).mean()
            mIoU = (intersection / (union + 1e-10)).mean()
            loss_meter.update(loss.item() / dist.get_world_size())

            if main_process(args):
                if callback is not None:
                    # Progress expressed in (fractional) epochs.
                    t = current_iter / len(train_loader)
                    callback.scalar('loss_train_batch',
                                    t,
                                    loss_meter.avg,
                                    title='Loss')
                    callback.scalars(['mIoU', 'mAcc', 'allAcc'],
                                     t, [mIoU, mAcc, allAcc],
                                     title='Training metrics')
                    # Only the first parameter group's learning rate is logged.
                    if optimizer.param_groups:
                        callback.scalar('lr',
                                        t,
                                        optimizer.param_groups[0]['lr'],
                                        title='Learning rate')

                log_slot = int(i // args.log_freq)
                train_losses[log_slot] = loss_meter.avg
                train_mIous[log_slot] = mIoU

    if args.scheduler != 'cosine':
        scheduler.step()

    return train_mIous, train_losses
Exemplo n.º 7
0
def train_helper_with_gradients_no_update(
        model: torchvision.models.resnet.ResNet,
        dataloaders: Dict[str, torch.utils.data.DataLoader],
        dataset_sizes: Dict[str, int], criterion: torch.nn.modules.loss,
        optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
        num_epochs: int, writer: IO, train_order_writer: IO,
        device: torch.device, start_epoch: int, batch_size: int,
        save_interval: int, checkpoints_folder: Path, num_layers: int,
        classes: List[str], num_classes: int, grad_csv: Path) -> None:
    """Log per-batch gradient magnitudes without updating the model.

    For every training batch the loss is back-propagated, but
    ``optimizer.step()`` is never called, so the weights stay unchanged. One
    CSV row per batch is written to ``grad_csv`` and echoed to stdout. It holds
    the image name of the first sample, the loss, the gradient magnitudes that
    ``get_grad_magnitude`` reports under the keys -1, 0, 60, 1, 20, 40 and 59,
    the maximum output value of the first sample, and whether the first
    sample's prediction matches its label.

    NOTE(review): name, confidence and correctness are taken from the first
    sample only, so the rows presumably assume a batch size of 1 -- confirm.

    Args:
        model: network whose gradients are inspected.
        dataloaders: must contain ``"train"``, yielding
            ``(inputs, labels, paths)`` batches.
        criterion: loss called as ``criterion(input=..., target=...)``.
        optimizer: only used to reset the gradients before each batch.
        scheduler: stepped once per epoch.
        num_epochs: number of passes over the training data.
        device: device the inputs and labels are moved to.
        grad_csv: destination of the CSV report (overwritten).
        dataset_sizes, writer, train_order_writer, start_epoch, batch_size,
            save_interval, checkpoints_folder, num_layers, classes,
            num_classes: accepted but not used by this function.
    """
    since = time.time()

    # The context manager guarantees the report is flushed and closed, even
    # when an exception interrupts the run.
    with open(str(grad_csv), "w") as mag_writer:
        mag_writer.write(
            "image_name,train_loss,layers_-1,layer_0,layer_60,layer_1,layer_20,layer_40,layer_59,conf,correct\n"
        )

        # Train for specified number of epochs.
        for _ in range(num_epochs):
            model.train(mode=True)

            for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
                train_inputs = inputs.to(device=device)
                train_labels = labels.to(device=device)
                optimizer.zero_grad()

                # Forward and backpropagation (no optimizer step on purpose).
                with torch.set_grad_enabled(mode=True):
                    train_outputs = model(train_inputs)
                    confs, train_preds = torch.max(train_outputs, dim=1)
                    train_loss = criterion(input=train_outputs,
                                           target=train_labels)
                    train_loss.backward(retain_graph=True)

                train_loss_npy = train_loss.item()
                layer_num_to_mag = get_grad_magnitude(model)
                image_name = get_image_name(paths[0])
                # Index the first sample explicitly, like the fields below.
                conf = float(confs[0].item())
                train_pred = int(train_preds[0].item())
                gt_label = int(train_labels[0].item())
                correct = 1 if train_pred == gt_label else 0

                output_line = f"{image_name},{train_loss_npy:.4f},{layer_num_to_mag[-1]:.4f},{layer_num_to_mag[0]:.4f},{layer_num_to_mag[60]:.4f},{layer_num_to_mag[1]:.4f},{layer_num_to_mag[20]:.4f},{layer_num_to_mag[40]:.4f},{layer_num_to_mag[59]:.4f},{conf:.4f},{correct}\n"
                mag_writer.write(output_line)
                print(idx, output_line)

            scheduler.step()

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")
Exemplo n.º 8
0
def train(
        model: torch.nn.Module,
        dataloaders: dict,
        criterion: torch.nn.Module,
        optimizer,
        scheduler: torch.optim.lr_scheduler,
        epochs: int,
        device: str,
        writer=None,
        model_name: str = 'base'
) -> Tuple[torch.nn.Module, list, list]:
    """
    Function to train model with given loss function, optimizer and scheduler. It operates in two phases: train and
    validation to allow to record results for validation set as well.

    Side effects: the model object is saved to ``output/model_<model_name>_margin_<margin>.pt`` and the loss
    histories are pickled to ``losses_model_<model_name>_margin_<margin>.pickle`` in the working directory.

    :param model: network to train; called as ``model(*data)``
    :param dataloaders: dict with the keys 'train' and 'val'; each loader yields ``(data, labels)`` and exposes
        ``.dataset`` (its length normalises the epoch loss)
    :param criterion: loss with a ``margin`` attribute; called as ``criterion(*outputs, labels)``, or
        ``criterion(*outputs)`` when ``labels`` is empty. May return the loss or a ``(loss, num_triplets)`` pair
    :param optimizer: optimizer stepped once per training batch
    :param scheduler: learning rate scheduler stepped once per epoch, after the training phase
    :param epochs: number of epochs
    :param device: device the data and labels are moved to
    :param writer: optional summary writer; the epoch loss is logged as 'Loss/train' resp. 'Loss/val'
    :param model_name: used in the names of the saved files
    :return: the model loaded with the weights of the lowest validation loss, the per-epoch training losses and
        the per-epoch validation losses
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 2e5
    total_loss_train, total_loss_val = [], []
    margin = criterion.margin
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0

            # Iterate over data.
            for data, labels in dataloaders[phase]:
                # Convert to tuple to avoid problems when unpacking value in model/loss forward call
                if not isinstance(data, (tuple, list)):
                    data = (data,)
                data = tuple(d.to(device) for d in data)
                if len(labels) > 0:
                    labels = labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(*data)
                    # Convert to tuple to avoid problems when unpacking value in model/loss forward call
                    if not isinstance(outputs, (tuple, list)):
                        outputs = (outputs,)
                    if len(labels) > 0:
                        loss_outputs = criterion(*outputs, labels)
                    else:
                        loss_outputs = criterion(*outputs)
                    if isinstance(loss_outputs, (tuple, list)):
                        loss, _num_triplets = loss_outputs
                    else:
                        loss = loss_outputs
                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()
                # statistics: weight the batch loss by the batch size
                running_loss += loss.item() * data[0].size(0)
            if is_train:
                scheduler.step()
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            if writer:
                # Log each phase under its own tag so that validation losses
                # do not end up in the training curve.
                writer.add_scalar(f'Loss/{phase}', epoch_loss, epoch)
            if is_train:
                total_loss_train.append(epoch_loss)
            else:
                total_loss_val.append(epoch_loss)
                if epoch_loss < best_loss:
                    print("New best model found")
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())

            print('{} Loss: {:.4f}'.format(
                phase, epoch_loss))

    os.makedirs('output/', exist_ok=True)
    # NOTE(review): this saves the weights of the last epoch; the best weights
    # are only restored afterwards -- confirm that this is intended.
    torch.save(model, f'output/model_{model_name}_margin_{margin}.pt')
    losses = {'train_loss': total_loss_train, 'val_loss': total_loss_val}
    with open(f'losses_model_{model_name}_margin_{margin}.pickle', 'wb') as f:
        pickle.dump(losses, f)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    model.load_state_dict(best_model_wts)
    return model, total_loss_train, total_loss_val
def train_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str,
                                     int], criterion: torch.nn.modules.loss,
                 optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
                 num_epochs: int, writer: IO, train_order_writer: IO,
                 device: torch.device, start_epoch: int, batch_size: int,
                 save_interval: int, checkpoints_folder: Path, num_layers: int,
                 classes: List[str], num_classes: int) -> None:
    """
    Function for training ResNet.

    Every 10th mini-batch (and additionally at mini-batch 5) the whole
    validation set is evaluated, confusion matrices are computed for both
    splits, a checkpoint is written to ``checkpoints_folder`` and one CSV row
    is appended to ``writer``. The model is switched back to training mode
    afterwards.

    Args:
        model: ResNet model for training.
        dataloaders: Dataloaders for IO pipeline; needs the keys "train" and
            "val", each yielding (inputs, labels, paths) batches.
        dataset_sizes: Sizes of the training and validation dataset.
        criterion: Metric used for calculating loss.
        optimizer: Optimizer to use for gradient descent.
        scheduler: Scheduler to use for learning rate decay (stepped once per epoch).
        start_epoch: Starting epoch for training.
        writer: Writer to write logging information.
        train_order_writer: Writer to write the order of training examples
            (currently unused).
        device: Device to use for running model.
        num_epochs: Total number of epochs to train for.
        batch_size: Mini-batch size to use for training.
        save_interval: Number of epochs between saving checkpoints
            (currently unused).
        checkpoints_folder: Directory to save model checkpoints to.
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
        classes: Names of the classes in the dataset.
        num_classes: Number of classes in the dataset.
    """
    since = time.time()

    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    # NOTE(review): mid-epoch confusion matrices read these buffers before
    # they are completely filled for the current epoch (uninitialised values
    # in the first epoch, entries of the previous epoch later) -- confirm
    # that this is acceptable.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    global_minibatch_counter = 0

    # Train for specified number of epochs.
    for epoch in range(start_epoch, num_epochs):

        # Training phase.
        model.train(mode=True)

        train_running_loss = 0.0
        train_running_corrects = 0
        epoch_minibatch_counter = 0

        # Train over all training data.
        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):

            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)

            start = idx * batch_size
            end = start + batch_size

            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1
            epoch_minibatch_counter += 1

            if global_minibatch_counter % 10 != 0 and global_minibatch_counter != 5:
                continue

            calculate_confusion_matrix(
                all_labels=train_all_labels.numpy(),
                all_predicts=train_all_predicts.numpy(),
                classes=classes,
                num_classes=num_classes)

            # Store training diagnostics (averaged over the samples seen so
            # far in this epoch, assuming full batches).
            samples_seen = epoch_minibatch_counter * batch_size
            avg_train_loss = train_running_loss / samples_seen
            train_acc = train_running_corrects / samples_seen

            # Validation phase.
            model.train(mode=False)

            val_running_loss = 0.0
            val_running_corrects = 0

            # Feed forward over all the validation data. Dedicated loop
            # variables keep the training loop's idx/paths/start/end intact.
            for val_idx, (val_inputs, val_labels,
                          _val_paths) in enumerate(dataloaders["val"]):
                val_inputs = val_inputs.to(device=device)
                val_labels = val_labels.to(device=device)

                # Feed forward.
                with torch.set_grad_enabled(mode=False):
                    val_outputs = model(val_inputs)
                    _, val_preds = torch.max(val_outputs, dim=1)
                    val_loss = criterion(input=val_outputs,
                                         target=val_labels)

                # Update validation diagnostics.
                val_running_loss += val_loss.item() * val_inputs.size(0)
                val_running_corrects += torch.sum(
                    val_preds == val_labels.data, dtype=torch.double)

                val_start = val_idx * batch_size
                val_end = val_start + batch_size

                val_all_labels[val_start:val_end] = val_labels.detach().cpu()
                val_all_predicts[val_start:val_end] = val_preds.detach().cpu()

            calculate_confusion_matrix(
                all_labels=val_all_labels.numpy(),
                all_predicts=val_all_predicts.numpy(),
                classes=classes,
                num_classes=num_classes)

            # Store validation diagnostics.
            avg_val_loss = val_running_loss / dataset_sizes["val"]
            val_acc = val_running_corrects / dataset_sizes["val"]

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Remaining things related to training.
            epoch_output_path = checkpoints_folder.joinpath(
                f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt"
            )

            # Confirm the output directory exists.
            epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

            # Save the model as a state dictionary.
            torch.save(obj={
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "epoch": epoch + 1
            },
                       f=str(epoch_output_path))

            writer.write(
                f"{epoch},{global_minibatch_counter},{avg_train_loss:.4f},"
                f"{train_acc:.4f},{avg_val_loss:.4f},{val_acc:.4f}\n")

            current_lr = None
            for group in optimizer.param_groups:
                current_lr = group["lr"]

            # Print the diagnostics for each epoch.
            print(f"Epoch {epoch} with "
                  f"mb {global_minibatch_counter} "
                  f"lr {current_lr:.15f}: "
                  f"t_loss: {avg_train_loss:.4f} "
                  f"t_acc: {train_acc:.4f} "
                  f"v_loss: {avg_val_loss:.4f} "
                  f"v_acc: {val_acc:.4f}\n")

            # Back to training mode: without this, the remaining batches of
            # the epoch would be trained with the model still in eval mode.
            model.train(mode=True)

        scheduler.step()

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")
Exemplo n.º 10
0
def train(args,
          worker_id: int,
          global_model: Union[ActorNetwork, ActorCriticNetwork],
          T: Value,
          global_reward: Value,
          optimizer: torch.optim.Optimizer = None,
          global_model_critic: CriticNetwork = None,
          optimizer_critic: torch.optim.Optimizer = None,
          lr_scheduler: torch.optim.lr_scheduler = None,
          lr_scheduler_critic: torch.optim.lr_scheduler = None):
    """
    Run one worker in training mode, i.e. train the shared model with backprop.

    Loosely based on https://github.com/ikostrikov/pytorch-a3c/blob/master/train.py

    Each pass of the outer loop copies the global weights into a local model,
    collects ``args.rollout_steps`` transitions, builds the policy / value /
    entropy losses, back-propagates through the local model, hands the
    gradients over via ``sync_grads`` and steps the optimizer(s).

    The outer ``while True`` loop has no exit condition, so the function does
    not return under normal operation.

    Fields read from ``args``: seed, worker, n_envs, env_name, log_dir,
    normalizer, optimizer, lr, lr_critic, lr_scheduler, lr_scheduler_step,
    shared_model, rollout_steps, max_action, max_episode_length,
    log_frequency, discount, tau, no_gae, value_loss_weight,
    entropy_loss_weight, max_grad_norm. ``args.n_envs`` is overwritten with 1
    when ``args.worker != 1``.

    :param args: console arguments
    :param worker_id: id of the worker; added to ``args.seed`` for the torch
        seed, and only worker 0 writes to TensorBoard and steps the LR
        scheduler
    :param global_model: global model which is optimized / for split models: actor
    :param T: global counter of steps (incremented by ``args.n_envs`` per step)
    :param global_reward: global running reward value (exponential moving
        average with factor 0.99, seeded with the first finished episode)
    :param optimizer: optimizer for shared model / for split models: actor
        model; when falsy, optimizers are built here with ``get_optimizer``
    :param global_model_critic: optional global critic model for split networks
    :param optimizer_critic: optional critic optimizer for split networks
    :param lr_scheduler: optional learning rate scheduler instance for shared
        model / for split models: actor learning rate scheduler
    :param lr_scheduler_critic: optional learning rate scheduler instance for
        critic model
    :return: None
    """
    torch.manual_seed(args.seed + worker_id)

    # args.worker == 1 is treated as the A2C case (see the log message); any
    # other value runs a single-environment A3C worker.
    if args.worker == 1:
        logging.info(f"Running A2C with {args.n_envs} environments.")
        if "RR" not in args.env_name:
            env = SubprocVecEnv([
                make_env(args.env_name, args.seed, i, args.log_dir)
                for i in range(args.n_envs)
            ])
        else:
            env = DummyVecEnv(
                [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
    else:
        logging.info(f"Running A3C: training worker {worker_id} started.")
        env = DummyVecEnv(
            [make_env(args.env_name, args.seed, worker_id, args.log_dir)])
        # avoid any issues if this is not 1
        args.n_envs = 1

    normalizer = get_normalizer(args.normalizer, env)

    # init local NN instance for worker thread
    model = copy.deepcopy(global_model)
    model.train()

    model_critic = None

    if global_model_critic:
        model_critic = copy.deepcopy(global_model_critic)
        model_critic.train()

    # if no shared optimizer is provided use individual one
    # (built over the *global* models' parameters, not the local copies)
    if not optimizer:
        optimizer, optimizer_critic = get_optimizer(
            args.optimizer,
            global_model,
            args.lr,
            model_critic=global_model_critic,
            lr_critic=args.lr_critic)
        if args.lr_scheduler == "exponential":
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                                  gamma=0.99)
            if optimizer_critic:
                lr_scheduler_critic = torch.optim.lr_scheduler.ExponentialLR(
                    optimizer_critic, gamma=0.99)

    state = torch.Tensor(env.reset())

    # per-environment step counter within the current episode
    t = np.zeros(args.n_envs)
    global_iter = 0
    episode_reward = np.zeros(args.n_envs)

    # `writer` exists only for worker 0; every later use is guarded by the
    # same `worker_id == 0` test.
    if worker_id == 0:
        writer = SummaryWriter(log_dir='experiments/runs/')

    while True:
        # Get state of the global model
        model.load_state_dict(global_model.state_dict())
        # NOTE(review): with shared_model False this requires a critic to have
        # been passed in; otherwise model_critic is still None here.
        if not args.shared_model:
            model_critic.load_state_dict(global_model_critic.state_dict())

        # containers for computing loss
        values = []
        log_probs = []
        rewards = []
        entropies = []
        # container to check whether a terminal state was reached from one of the envs
        terminals = []

        # reward_sum = 0
        for step in range(args.rollout_steps):
            t += 1

            # shared network returns (value, mu, std); split networks return
            # (mu, std) from the actor and the value from the critic
            if args.shared_model:
                value, mu, std = model(normalizer(state))
            else:
                mu, std = model(normalizer(state))
                value = model_critic(normalizer(state))

            dist = torch.distributions.Normal(mu, std)

            # ------------------------------------------
            # # select action
            action = dist.sample()

            # ------------------------------------------
            # Compute statistics for loss
            # (summed over the last dimension, then given a trailing axis of 1)
            entropy = dist.entropy().sum(-1).unsqueeze(-1)
            log_prob = dist.log_prob(action).sum(-1).unsqueeze(-1)

            # make selected move
            # (the log-prob above was taken on the unclipped sample)
            action = np.clip(action.detach().numpy(), -args.max_action,
                             args.max_action)
            state, reward, dones, _ = env.step(
                action[0]
                if not args.worker == 1 or "RR" in args.env_name else action)

            reward = shape_reward(args, reward)

            episode_reward += reward

            # probably don't set terminal state if max_episode length
            dones = np.logical_or(dones, t >= args.max_episode_length)

            values.append(value)
            log_probs.append(log_prob)
            rewards.append(torch.Tensor(reward).unsqueeze(-1))
            entropies.append(entropy)
            # stored as a continuation mask: 0 where the episode ended, 1 otherwise
            terminals.append(torch.Tensor(1 - dones).unsqueeze(-1))

            for i, done in enumerate(dones):
                if done:
                    # keep track of the avg overall global reward
                    with global_reward.get_lock():
                        if global_reward.value == -np.inf:
                            global_reward.value = episode_reward[i]
                        else:
                            global_reward.value = .99 * global_reward.value + .01 * episode_reward[
                                i]
                    if worker_id == 0 and T.value % args.log_frequency == 0:
                        writer.add_scalar("reward/global", global_reward.value,
                                          T.value)

                    episode_reward[i] = 0
                    t[i] = 0
                    # NOTE(review): the observation returned by env.reset() is
                    # discarded, so `state` still holds the value returned by
                    # env.step() above — confirm this is intended.
                    if args.worker != 1 or "RR" in args.env_name:
                        env.reset()

            with T.get_lock():
                # this is one for a3c and n for A2C (actually the lock is not needed for A2C)
                T.value += args.n_envs

            # NOTE(review): `T.value % args.lr_scheduler_step` is truthy whenever
            # T is NOT a multiple of lr_scheduler_step, so this branch runs on
            # almost every step with a fractional epoch argument. If the intent
            # was "once every lr_scheduler_step steps" the test would need
            # `== 0` — confirm before changing. T.value is read here without
            # holding its lock.
            if lr_scheduler and worker_id == 0 and T.value % args.lr_scheduler_step and global_iter != 0:
                lr_scheduler.step(T.value / args.lr_scheduler_step)

                if lr_scheduler_critic:
                    lr_scheduler_critic.step(T.value / args.lr_scheduler_step)

            state = torch.Tensor(state)

        # bootstrap value of the state reached after the last rollout step
        if args.shared_model:
            v, _, _ = model(normalizer(state))
            G = v.detach()
        else:
            G = model_critic(normalizer(state)).detach()

        values.append(G)

        # compute loss and backprop
        advantages = torch.zeros((args.n_envs, 1))

        ret = torch.zeros((args.rollout_steps, args.n_envs, 1))
        adv = torch.zeros((args.rollout_steps, args.n_envs, 1))

        # iterate over all time steps from most recent to the starting one
        for i in reversed(range(args.rollout_steps)):
            # G can be seen essentially as the return over the course of the rollout
            G = rewards[i] + args.discount * terminals[i] * G
            if not args.no_gae:
                # Generalized Advantage Estimation
                td_error = rewards[i] + args.discount * terminals[i] * values[
                    i + 1] - values[i]
                # terminals here to "reset" advantages to 0, because reset is called internally in the env
                # and new trajectory started
                advantages = advantages * args.discount * args.tau * terminals[
                    i] + td_error
            else:
                advantages = G - values[i].detach()

            # both targets are detached, so no gradient flows through them
            adv[i] = advantages.detach()
            ret[i] = G.detach()

        policy_loss = -(torch.stack(log_probs) * adv).mean()
        # minus 1 in order to remove the last element, which is only necessary for next timestep value
        value_loss = .5 * (ret - torch.stack(values[:-1])).pow(2).mean()
        entropy_loss = torch.stack(entropies).mean()

        # zero grads to reset the gradients
        optimizer.zero_grad()

        if args.shared_model:
            # combined loss for shared architecture
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss
            total_loss.backward()
        else:
            optimizer_critic.zero_grad()

            # NOTE(review): the critic loss is back-propagated without
            # args.value_loss_weight; the weight only enters total_loss below.
            value_loss.backward()
            (policy_loss - args.entropy_loss_weight * entropy_loss).backward()

            # this is just used for plotting in tensorboard
            total_loss = policy_loss + args.value_loss_weight * value_loss - args.entropy_loss_weight * entropy_loss

        # clip on the local model, then hand the gradients to the global model
        # (sync_grads is defined elsewhere; presumably it copies local grads
        # onto the global parameters — verify)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        sync_grads(model, global_model)
        optimizer.step()

        if not args.shared_model:
            torch.nn.utils.clip_grad_norm_(model_critic.parameters(),
                                           args.max_grad_norm)
            sync_grads(model_critic, global_model_critic)
            optimizer_critic.step()

        global_iter += 1

        if worker_id == 0 and T.value % args.log_frequency == 0:
            log_to_tensorboard(writer,
                               model,
                               optimizer,
                               rewards,
                               values,
                               total_loss,
                               policy_loss,
                               value_loss,
                               entropy_loss,
                               T.value,
                               model_critic=model_critic,
                               optimizer_critic=optimizer_critic)
Exemplo n.º 11
0
    def _train_helper(self, model: torchvision.models.resnet.ResNet,
                      dataloaders: Dict[str, torch.utils.data.DataLoader],
                      dataset_sizes: Dict[str, int], loss_fn,
                      optimizer: torch.optim,
                      scheduler: torch.optim.lr_scheduler, start_epoch: int,
                      writer: IO) -> None:
        """
        Train and validate a ResNet epoch by epoch, keeping the best checkpoint.

        Every epoch runs one pass over ``dataloaders["train"]`` with gradient
        updates and one pass over ``dataloaders["val"]`` without gradients,
        reports a confusion matrix for each, steps the scheduler, appends a
        CSV row to ``writer`` and logs a summary line. A checkpoint is written
        whenever validation accuracy beats the best value seen so far, and the
        loop ends early once the early stopper says so.

        Args:
            model: ResNet model for learning.
            dataloaders: Dataloaders for IO pipeline ("train" and "val" keys).
            dataset_sizes: Sizes of the learning and validation dataset.
            loss_fn: Metric used for calculating loss; called with
                ``logits=`` and ``target=`` keyword arguments.
            optimizer: Optimizer to use for gradient descent.
            scheduler: Scheduler to use for learning rate decay.
            start_epoch: Starting epoch for learning.
            writer: Writer to write logging information.
        """
        run_started = time.time()

        n_train = dataset_sizes["train"]
        n_val = dataset_sizes["val"]
        device = self._device
        batch_size = self._batch_size

        def _cpu_long_buffer(length: int) -> torch.Tensor:
            # Allocated once up front; each epoch overwrites the contents.
            return torch.empty(size=(length, ), dtype=torch.long).cpu()

        train_all_labels = _cpu_long_buffer(n_train)
        train_all_predicts = _cpu_long_buffer(n_train)
        val_all_labels = _cpu_long_buffer(n_val)
        val_all_predicts = _cpu_long_buffer(n_val)

        early_stopper = EarlyStopper(patience=self._early_stopping_patience,
                                     mode=EarlyStopper.Mode.MAX)

        # When resuming, start from the accuracy recorded with the checkpoint.
        resumed = self._resume_checkpoint and self._last_val_acc
        best_val_acc = self._last_val_acc if resumed else 0.

        for epoch in range(start_epoch, self._num_epochs):
            epoch_started = time.time()

            # ---------------- training pass ----------------
            model.train()
            loss_sum = 0.0
            hits = 0

            for step, (batch, targets) in enumerate(dataloaders["train"]):
                patches = batch["patch"].to(device=device)
                xs = batch["x_coord"].to(device=device)
                ys = batch["y_coord"].to(device=device)
                targets = targets.to(device=device)
                optimizer.zero_grad()

                with torch.enable_grad():
                    logits = model(patches, xs, ys).squeeze(dim=1)
                    batch_loss = loss_fn(logits=logits, target=targets)
                    batch_loss.backward()
                    optimizer.step()

                # Weight the batch loss by the number of samples it covers.
                loss_sum += batch_loss.item() * patches.size(0)
                guesses = self._extract_pred_labels(logits)
                hits += torch.sum(guesses == targets.data, dtype=torch.double)

                window = slice(step * batch_size, (step + 1) * batch_size)
                train_all_labels[window] = targets.detach().cpu()
                train_all_predicts[window] = guesses.detach().cpu()

            self._calculate_confusion_matrix(
                all_labels=train_all_labels.numpy(),
                all_predicts=train_all_predicts.numpy(),
                classes=self._classes,
                num_classes=self._num_classes)

            train_loss = loss_sum / n_train
            train_acc = hits / n_train

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # ---------------- validation pass ----------------
            model.eval()
            loss_sum = 0.0
            hits = 0

            for step, (batch, targets) in enumerate(dataloaders["val"]):
                patches = batch["patch"].to(device=device)
                xs = batch["x_coord"].to(device=device)
                ys = batch["y_coord"].to(device=device)
                targets = targets.to(device=device)

                with torch.no_grad():
                    logits = model(patches, xs, ys).squeeze(dim=1)
                    batch_loss = loss_fn(logits=logits, target=targets)

                loss_sum += batch_loss.item() * patches.size(0)
                guesses = self._extract_pred_labels(logits)
                hits += torch.sum(guesses == targets.data, dtype=torch.double)

                window = slice(step * batch_size, (step + 1) * batch_size)
                val_all_labels[window] = targets.detach().cpu()
                val_all_predicts[window] = guesses.detach().cpu()

            self._calculate_confusion_matrix(
                all_labels=val_all_labels.numpy(),
                all_predicts=val_all_predicts.numpy(),
                classes=self._classes,
                num_classes=self._num_classes)

            val_loss = loss_sum / n_val
            val_acc = hits / n_val

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            scheduler.step()

            # Report the rate of the last parameter group (None if there are none).
            group_rates = [group["lr"] for group in optimizer.param_groups]
            current_lr = group_rates[-1] if group_rates else None

            # ---------------- checkpoint on improvement ----------------
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                ckpt_name = (
                    f"resnet{self._num_layers}_e{epoch}_va{val_acc:.5f}.pt")
                best_model_ckpt_path = self._checkpoints_folder.joinpath(
                    ckpt_name)
                best_model_ckpt_path.parent.mkdir(parents=True, exist_ok=True)

                snapshot = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "epoch": epoch + 1
                }
                torch.save(obj=snapshot, f=str(best_model_ckpt_path))

                self._clean_ckpt_folder(best_model_ckpt_path)

            # ---------------- bookkeeping ----------------
            writer.write(
                f"{epoch},{train_loss:.4f},{train_acc:.4f},"
                f"{val_loss:.4f},{val_acc:.4f}\n")

            elapsed = self._format_time_period(epoch_started, time.time())
            summary = (f"Epoch {epoch} with lr {current_lr:.15f}: "
                       f"{elapsed} "
                       f"t_loss: {train_loss:.4f} t_acc: {train_acc:.4f} "
                       f"v_loss: {val_loss:.4f} v_acc: {val_acc:.4f}\n")
            logging.info(summary)

            early_stopper.update(val_acc)
            if early_stopper.is_stopping():
                logging.info("Early stopping")
                break

        total_elapsed = self._format_time_period(run_started, time.time())
        logging.info(f"\nlearning complete in {total_elapsed}")
Exemplo n.º 12
0
def train_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str,
                                     int], criterion: torch.nn.modules.loss,
                 optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
                 num_epochs: int, log_writer: IO, train_order_writer: IO,
                 device: torch.device, batch_size: int,
                 checkpoints_folder: Path, num_layers: int, classes: List[str],
                 minibatch_counter,
                 num_classes: int) -> "Tuple[Optional[Path], int]":
    """
    Train and validate a ResNet for ``num_epochs`` epochs, saving every epoch.

    Each epoch runs one optimisation pass over ``dataloaders["train"]`` and one
    gradient-free pass over ``dataloaders["val"]``, reports a confusion matrix
    for both, writes a checkpoint, appends a CSV row to ``log_writer``, prints
    a summary line and finally steps the scheduler.

    Args:
        model: Network to train; called as ``model(inputs)``.
        dataloaders: Loaders under the keys "train" and "val"; each yields
            ``(inputs, labels, paths)`` triples.
        dataset_sizes: Sample counts under "train" and "val"; used to size the
            label/prediction buffers and to average the validation metrics.
        criterion: Loss; called with ``input=`` and ``target=`` keywords.
        optimizer: Optimizer stepped once per training minibatch.
        scheduler: LR scheduler stepped once per epoch, after the checkpoint
            for that epoch has been written.
        num_epochs: Number of epochs; epochs are numbered from 1.
        log_writer: Receives one CSV row per epoch.
        train_order_writer: Not used by this function; kept for interface
            compatibility.
        device: Device the batches are moved to.
        batch_size: Nominal loader batch size, used to place each batch in the
            label/prediction buffers.
        checkpoints_folder: Directory the checkpoints are written to.
        num_layers: Only used in the checkpoint file name.
        classes: Class names forwarded to ``calculate_confusion_matrix``.
        minibatch_counter: Starting value of the global minibatch counter.
        num_classes: Forwarded to ``calculate_confusion_matrix``.

    Returns:
        ``(last_checkpoint_path, global_minibatch_counter)``. The path is
        ``None`` when no epoch ran (``num_epochs < 1``).
    """
    since = time.time()
    global_minibatch_counter = minibatch_counter
    # Defined up front so the final return cannot hit an unbound local when
    # the epoch loop never executes.
    epoch_output_path = None

    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    for epoch in range(1, num_epochs + 1):

        model.train(mode=True)  # Training phase.
        train_running_loss, train_running_corrects = 0.0, 0
        # Real number of samples seen this epoch. The running loss is weighted
        # by the true batch size, so it must be averaged over the same count;
        # minibatches * batch_size over-counts when the last batch is short.
        train_samples_seen = 0

        for idx, (inputs, labels, _paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                batch_loss = criterion(input=train_outputs,
                                       target=train_labels)
                batch_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            train_running_loss += batch_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)
            train_samples_seen += train_inputs.size(0)

            # Batch size read from the tensor metadata (no device copy needed).
            start = idx * batch_size
            end = start + train_labels.size(0)
            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1

        # Calculate training diagnostics
        calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
                                   all_predicts=train_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)
        train_loss = train_running_loss / train_samples_seen
        train_acc = train_running_corrects / train_samples_seen

        # Validation phase.
        model.train(mode=False)
        val_running_loss = 0.0
        val_running_corrects = 0

        # Feed forward over all the validation data.
        for idx, (val_inputs, val_labels,
                  _paths) in enumerate(dataloaders["val"]):
            val_inputs = val_inputs.to(device=device)
            val_labels = val_labels.to(device=device)

            # Feed forward.
            with torch.set_grad_enabled(mode=False):
                val_outputs = model(val_inputs)
                _, val_preds = torch.max(val_outputs, dim=1)
                batch_val_loss = criterion(input=val_outputs,
                                           target=val_labels)

            # Update validation diagnostics.
            val_running_loss += batch_val_loss.item() * val_inputs.size(0)
            val_running_corrects += torch.sum(val_preds == val_labels.data,
                                              dtype=torch.double)

            start = idx * batch_size
            end = start + val_labels.size(0)
            val_all_labels[start:end] = val_labels.detach().cpu()
            val_all_predicts[start:end] = val_preds.detach().cpu()

        # Calculate validation diagnostics
        calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
                                   all_predicts=val_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)
        val_loss = val_running_loss / dataset_sizes["val"]
        val_acc = val_running_corrects / dataset_sizes["val"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Remaining things related to training.
        epoch_output_path = checkpoints_folder.joinpath(
            f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt"
        )
        epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the model as a state dictionary.
        # NOTE(review): the scheduler state is captured before this epoch's
        # scheduler.step() while "epoch" already points at the next epoch —
        # confirm the resume code expects that.
        torch.save(obj={
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch + 1
        },
                   f=str(epoch_output_path))

        log_writer.write(
            f"{epoch},{global_minibatch_counter},{train_loss:.4f},{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n"
        )

        # Learning rate of the last parameter group, i.e. the rate this epoch
        # was trained with (the scheduler has not been stepped yet).
        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]

        # Print the diagnostics for each epoch.
        print(f"Epoch {epoch} with "
              f"mb {global_minibatch_counter} "
              f"lr {current_lr:.15f}: "
              f"t_loss: {train_loss:.4f} "
              f"t_acc: {train_acc:.4f} "
              f"v_loss: {val_loss:.4f} "
              f"v_acc: {val_acc:.4f}\n")

        scheduler.step()

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")

    return epoch_output_path, global_minibatch_counter
def train_smartgrad_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str, int],
                 criterion: torch.nn.modules.loss,
                 optimizer: torch.optim,
                 scheduler: torch.optim.lr_scheduler,
                 num_epochs: int,
                 log_writer: IO,
                 train_order_writer: IO,
                 device: torch.device,
                 train_batch_size: int,
                 val_batch_size: int,
                 fake_minibatch_size: int,
                 annealling_factor: float,
                 save_mb_interval: int,
                 val_mb_interval: int,
                 checkpoints_folder: Path,
                 num_layers: int,
                 classes: List[str],
                 num_classes: int) -> None:
    """
    Train a ResNet with re-weighted gradients accumulated over "fake" minibatches.

    For every loader batch the loss is back-propagated and the resulting
    gradients are captured with ``model_to_grad_as_dict_and_flatten``. Once
    ``fake_minibatch_size`` consecutive batches have been captured, weights
    are obtained from ``get_idx_to_weight``, each parameter's ``.grad`` is
    replaced by the result of ``get_new_layer_grad`` and the optimizer is
    stepped once. Those three helpers are defined elsewhere in the file.

    Validation runs on the first minibatch and then every ``val_mb_interval``
    minibatches; a checkpoint is additionally written when the counter is 1 or
    a multiple of ``save_mb_interval``. The scheduler is stepped once per
    epoch.

    Args:
        model: Network to train; called as ``model(inputs)``.
        dataloaders: Loaders under "train" and "val"; each yields
            ``(inputs, labels, paths)`` triples.
        dataset_sizes: Sample counts under "train" and "val".
        criterion: Loss; called with ``input=`` and ``target=`` keywords.
        optimizer: Optimizer, stepped once per completed fake minibatch.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: Number of epochs; epochs are numbered from 1.
        log_writer: Receives one CSV row per validation run.
        train_order_writer: Optional; when truthy, the last two path
            components of every training sample are written to it in order.
        device: Device the batches are moved to.
        train_batch_size: Nominal training loader batch size (buffer offsets).
        val_batch_size: Nominal validation loader batch size (buffer offsets).
        fake_minibatch_size: Number of loader batches per optimizer step.
        annealling_factor: Forwarded to ``get_idx_to_weight``.
        save_mb_interval: Checkpoint interval in minibatches.
        val_mb_interval: Validation interval in minibatches.
        checkpoints_folder: Directory the checkpoints are written to.
        num_layers: Only used in the checkpoint file name.
        classes: Class names forwarded to ``calculate_confusion_matrix``.
        num_classes: Forwarded to ``calculate_confusion_matrix``.
    """
    grad_layers = list(range(1, 21))

    since = time.time()
    global_minibatch_counter = 0
    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    for epoch in range(1, num_epochs + 1):

        # NOTE(review): the model is put in eval mode for the training phase.
        # Only element 0 of each label batch is recorded below, which suggests
        # single-example batches, so this may be deliberate (e.g. to avoid
        # batch-norm updates) — confirm before switching to mode=True.
        model.train(mode=False)
        train_running_loss, train_running_corrects = 0.0, 0
        # Real number of samples seen so far this epoch. The running loss is
        # weighted by the true batch size, so it is averaged over this count
        # rather than minibatches * train_batch_size, which over-counts when
        # the last batch is short.
        train_samples_seen = 0
        idx_to_gt = {}

        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                batch_loss = criterion(input=train_outputs,
                                       target=train_labels)
                batch_loss.backward(retain_graph=True)

                # Ground-truth label of the first element of this batch.
                idx_to_gt[idx] = int(train_labels[0].item())

                # Position of this batch inside the current fake minibatch;
                # a new window drops the previously captured gradients.
                fake_minibatch_idx = idx % fake_minibatch_size
                if fake_minibatch_idx == 0:
                    minibatch_grad_dict = {}
                    gc.collect()

                # Capture this batch's gradients and add them to the window.
                grad_as_dict, grad_flattened = model_to_grad_as_dict_and_flatten(
                    model, grad_layers)
                minibatch_grad_dict[idx] = (grad_as_dict, grad_flattened)

                # Window complete: weight the captured gradients and step.
                if fake_minibatch_idx == fake_minibatch_size - 1:
                    idx_to_weight_batch = get_idx_to_weight(
                        minibatch_grad_dict, annealling_factor, idx_to_gt)
                    print(idx_to_weight_batch)

                    for layer_num, param in enumerate(model.parameters()):
                        new_grad = get_new_layer_grad(layer_num,
                                                      idx_to_weight_batch,
                                                      minibatch_grad_dict)
                        # Compare shapes from tensor metadata; no need to copy
                        # every gradient to the host just to read its shape.
                        assert param.grad.shape == new_grad.shape
                        param.grad = new_grad
                    optimizer.step()

            # Update training diagnostics.
            train_running_loss += batch_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)
            train_samples_seen += train_inputs.size(0)

            start = idx * train_batch_size
            end = start + train_labels.size(0)
            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1

            # Write the path of training order if it exists
            if train_order_writer:
                for path in paths:  # write the order that the model was trained in
                    train_order_writer.write("/".join(path.split("/")[-2:]) + "\n")

            # Validate the model
            if global_minibatch_counter % val_mb_interval == 0 or global_minibatch_counter == 1:

                # Calculate training diagnostics
                calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
                                           all_predicts=train_all_predicts.numpy(),
                                           classes=classes,
                                           num_classes=num_classes)
                train_loss = train_running_loss / train_samples_seen
                train_acc = train_running_corrects / train_samples_seen

                # Validation phase.
                model.train(mode=False)
                val_running_loss = 0.0
                val_running_corrects = 0

                # Feed forward over all the validation data. Separate loop
                # names keep the training loop's `idx` / `paths` intact.
                for val_idx, (val_inputs, val_labels,
                              _val_paths) in enumerate(dataloaders["val"]):
                    val_inputs = val_inputs.to(device=device)
                    val_labels = val_labels.to(device=device)

                    # Feed forward.
                    with torch.set_grad_enabled(mode=False):
                        val_outputs = model(val_inputs)
                        _, val_preds = torch.max(val_outputs, dim=1)
                        batch_val_loss = criterion(input=val_outputs,
                                                   target=val_labels)

                    # Update validation diagnostics.
                    val_running_loss += batch_val_loss.item() * val_inputs.size(0)
                    val_running_corrects += torch.sum(
                        val_preds == val_labels.data, dtype=torch.double)

                    start = val_idx * val_batch_size
                    end = start + val_labels.size(0)
                    val_all_labels[start:end] = val_labels.detach().cpu()
                    val_all_predicts[start:end] = val_preds.detach().cpu()

                # Calculate validation diagnostics
                calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
                                           all_predicts=val_all_predicts.numpy(),
                                           classes=classes,
                                           num_classes=num_classes)
                val_loss = val_running_loss / dataset_sizes["val"]
                val_acc = val_running_corrects / dataset_sizes["val"]

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Remaining things related to training.
                if global_minibatch_counter % save_mb_interval == 0 or global_minibatch_counter == 1:

                    epoch_output_path = checkpoints_folder.joinpath(
                        f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt")
                    epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

                    # Save the model as a state dictionary.
                    torch.save(obj={
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch + 1
                    }, f=str(epoch_output_path))

                log_writer.write(
                    f"{epoch},{global_minibatch_counter},{train_loss:.4f},{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

                # Learning rate of the last parameter group.
                current_lr = None
                for group in optimizer.param_groups:
                    current_lr = group["lr"]

                # Print the diagnostics for each epoch.
                print(f"Epoch {epoch} with "
                    f"mb {global_minibatch_counter} "
                    f"lr {current_lr:.15f}: "
                    f"t_loss: {train_loss:.4f} "
                    f"t_acc: {train_acc:.4f} "
                    f"v_loss: {val_loss:.4f} "
                    f"v_acc: {val_acc:.4f}\n")

        scheduler.step()
def train(
    model: TEDD1104,
    optimizer_name: str,
    optimizer: torch.optim,
    scheduler: torch.optim.lr_scheduler,
    scaler: GradScaler,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    accumulation_steps: int,
    initial_epoch: int,
    num_epoch: int,
    running_loss: float,
    total_batches: int,
    total_training_examples: int,
    max_acc: float,
    hide_map_prob: float,
    dropout_images_prob: List[float],
    fp16: bool = True,
    save_checkpoints: bool = True,
    save_every: int = 20,
    save_best: bool = True,
):
    """
    Train a model with gradient accumulation and optional mixed precision,
    evaluating on the dev and test sets after every epoch.

    NOTE(review): ``device``, ``print_message``, ``evaluate``, ``save_model``
    and ``save_checkpoint`` are not parameters; they are presumably defined at
    module level elsewhere in the file.

    Input:
    - model: TEDD1104 model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam]; only forwarded
      to save_checkpoint
    - optimizer: Optimizer (torch.optim)
    - scheduler: Learning rate scheduler. It is stepped after every optimizer
      update with the value ``running_loss / total_batches`` (so a
      metric-driven scheduler is presumably expected -- confirm with caller)
    - scaler: GradScaler used to scale the loss and step the optimizer when
      fp16 is True
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
      (created if it does not exist)
    - batch_size: Batch size (Around 10 for 8GB GPU). Evaluation uses batch_size // 2
    - accumulation_steps: Number of mini-batches whose gradients are accumulated
      before each optimizer update; each loss is divided by this value
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epoch: Number of epochs to do
    - running_loss: Accumulated training loss to start from; it keeps growing
      during this call (presumably 0 unless restored from checkpoint)
    - total_batches: Number of optimizer updates to start from; incremented on
      every update (presumably 0 unless restored from checkpoint)
    - total_training_examples: Number of training examples seen so far;
      incremented by the size of every mini-batch
    - max_acc: Accuracy in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability for removing the minimap (put a black square)
       from a training example (0<=hide_map_prob<=1)
    - dropout_images_prob List of 5 floats or None, probability for removing each input image during training
     (black image) from a training example (0<=dropout_images_prob<=1)
    - fp16: Use FP16 for training
    - save_checkpoints: save a checkpoint whenever
      (total_batches + 1) % save_every == 0 (Each checkpoint will rewrite the previous one)
    - save_every: Checkpoint interval, measured in optimizer updates
    - save_best: save the model that achieves the higher accuracy in the development set

    Output:
     - float: Accuracy in the development test of the best model
    """

    if not os.path.exists(output_dir):
        print(f"{output_dir} does not exits. We will create it.")
        os.makedirs(output_dir)

    writer: SummaryWriter = SummaryWriter()

    criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss().to(device)
    model.zero_grad()
    print_message("Training...")
    for epoch in range(num_epoch):
        acc_dev: float = 0.0
        # num_batches counts optimizer updates in this epoch, step_no counts mini-batches.
        num_batches: int = 0
        step_no: int = 0

        # The training loader is rebuilt (and reshuffled) at the start of every epoch.
        data_loader_train = DataLoader(
            Tedd1104Dataset(
                dataset_dir=train_dir,
                hide_map_prob=hide_map_prob,
                dropout_images_prob=dropout_images_prob,
            ),
            batch_size=batch_size,
            shuffle=True,
            num_workers=os.cpu_count(),
            pin_memory=True,
        )
        start_time: float = time.time()
        step_start_time: float = time.time()
        dataloader_delay: float = 0
        model.train()
        for batch in data_loader_train:

            # Stack the five images of every example along a new dim 1 and then
            # merge dims 0 and 1, so the model receives one flat leading dimension.
            x = torch.flatten(
                torch.stack(
                    (
                        batch["image1"],
                        batch["image2"],
                        batch["image3"],
                        batch["image4"],
                        batch["image5"],
                    ),
                    dim=1,
                ),
                start_dim=0,
                end_dim=1,
            ).to(device)

            y = batch["y"].to(device)
            # Time spent since the previous step ended: fetching the batch plus
            # building x / y and moving them to the device.
            dataloader_delay += time.time() - step_start_time

            total_training_examples += len(y)

            if fp16:
                with autocast():
                    outputs = model.forward(x)
                    loss = criterion(outputs, y)
                    # Scale the loss so the accumulated gradient is an average
                    # over the accumulation window.
                    loss = loss / accumulation_steps

                running_loss += loss.item()
                scaler.scale(loss).backward()

            else:
                outputs = model.forward(x)
                loss = criterion(outputs, y) / accumulation_steps
                running_loss += loss.item()
                loss.backward()

            if ((step_no + 1) % accumulation_steps == 0) or (
                    step_no + 1 >= len(data_loader_train)
            ):  # If we are in the last batch of the epoch we also want to perform gradient descent
                if fp16:
                    # Gradient clipping
                    # Unscale first so the 1.0 threshold is applied to the real gradients.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                else:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    optimizer.zero_grad()

                total_batches += 1
                num_batches += 1
                # The scheduler receives the running mean loss per optimizer update.
                scheduler.step(running_loss / total_batches)

                # Time of the last update multiplied by the number of updates
                # still pending in this epoch.
                batch_time = round(time.time() - start_time, 2)
                est: float = batch_time * (math.ceil(
                    len(data_loader_train) / accumulation_steps) - num_batches)
                print_message(
                    f"EPOCH: {initial_epoch + epoch}. "
                    f"{num_batches} of {math.ceil(len(data_loader_train)/accumulation_steps)} batches. "
                    f"Total examples used for training {total_training_examples}. "
                    f"Iteration time: {batch_time} secs. "
                    f"Data Loading bottleneck: {round(dataloader_delay, 2)} secs. "
                    f"Epoch estimated time: "
                    f"{str(datetime.timedelta(seconds=est)).split('.')[0]}")

                print_message(
                    f"Loss: {running_loss / total_batches}. "
                    f"Learning rate {optimizer.state_dict()['param_groups'][0]['lr']}"
                )

                writer.add_scalar("Loss/train", running_loss / total_batches,
                                  total_batches)

                # The single checkpoint file is overwritten on every save and
                # stores the best dev accuracy seen so far (max_acc).
                if save_checkpoints and (total_batches + 1) % save_every == 0:
                    print_message("Saving checkpoint...")
                    save_checkpoint(
                        path=os.path.join(output_dir, "checkpoint.pt"),
                        model=model,
                        optimizer_name=optimizer_name,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        running_loss=running_loss,
                        total_batches=total_batches,
                        total_training_examples=total_training_examples,
                        acc_dev=max_acc,
                        epoch=initial_epoch + epoch,
                        fp16=fp16,
                        scaler=None if not fp16 else scaler,
                    )

                # Restart the per-update timers.
                dataloader_delay: float = 0
                start_time: float = time.time()

            step_no += 1
            step_start_time = time.time()

        # Drop the training loader before the evaluation loaders are created.
        del data_loader_train

        print_message("Dev set evaluation...")

        start_time_eval: float = time.time()

        # Evaluation data is loaded with hide_map_prob=0 and all image dropout
        # probabilities set to 0.
        data_loader_dev = DataLoader(
            Tedd1104Dataset(
                dataset_dir=dev_dir,
                hide_map_prob=0,
                dropout_images_prob=[0, 0, 0, 0, 0],
            ),
            batch_size=batch_size //
            2,  # Use smaller batch size to prevent OOM issues
            shuffle=False,
            num_workers=os.cpu_count() // 2,  # Use less cores to save RAM
            pin_memory=True,
        )

        acc_dev: float = evaluate(
            model=model,
            data_loader=data_loader_dev,
            device=device,
            fp16=fp16,
        )

        del data_loader_dev

        print_message("Test set evaluation...")
        data_loader_test = DataLoader(
            Tedd1104Dataset(
                dataset_dir=test_dir,
                hide_map_prob=0,
                dropout_images_prob=[0, 0, 0, 0, 0],
            ),
            batch_size=batch_size //
            2,  # Use smaller batch size to prevent OOM issues
            shuffle=False,
            num_workers=os.cpu_count() // 2,  # Use less cores to save RAM
            pin_memory=True,
        )

        acc_test: float = evaluate(
            model=model,
            data_loader=data_loader_test,
            device=device,
            fp16=fp16,
        )

        del data_loader_test

        print_message(
            f"Acc dev set: {round(acc_dev*100,2)}. "
            f"Acc test set: {round(acc_test*100,2)}.  "
            f"Eval time: {round(time.time() - start_time_eval,2)} secs.")

        # Chained comparison: acc_dev must be positive AND beat the best so far.
        if 0.0 < acc_dev > max_acc and save_best:
            max_acc = acc_dev
            print_message(
                f"New max acc in dev set {round(max_acc, 2)}. Saving model...")
            save_model(
                model=model,
                save_dir=output_dir,
                fp16=fp16,
            )

        # NOTE(review): the x-axis here is the local epoch index, not
        # initial_epoch + epoch, so resumed runs restart at 0 -- confirm intended.
        writer.add_scalar("Accuracy/dev", acc_dev, epoch)
        writer.add_scalar("Accuracy/test", acc_test, epoch)

    return max_acc
Exemplo n.º 15
0
def train(
    model: DRIVEMODEL,
    optimizer_name: str,
    optimizer: torch.optim,
    scheduler: torch.optim.lr_scheduler,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    accumulation_steps: int,
    initial_epoch: int,
    num_epoch: int,
    max_acc: float,
    hide_map_prob: float,
    dropout_images_prob: List[float],
    num_load_files_training: int,
    fp16: bool = True,
    amp_opt_level=None,
    save_checkpoints: bool = True,
    eval_every: int = 5,
    save_every: int = 20,
    save_best: bool = True,
):
    """
    Train a model, streaming the training set in groups of files.

    Each "iteration" loads ``num_load_files_training`` files, trains on them in
    mini-batches with gradient accumulation, steps the scheduler with the mean
    loss of the iteration and periodically evaluates / checkpoints.

    NOTE(review): ``device``, ``printTrace``, ``load_dataset``, ``nn_batchs``,
    ``evaluate``, ``save_model`` and ``save_checkpoint`` are not parameters;
    they are presumably defined at module level elsewhere in the file.

    Input:
    - model: DRIVEMODEL model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam]; only forwarded
      to save_checkpoint
    - optimizer: Optimizer (torch.optim)
    - scheduler: Learning rate scheduler, stepped once per iteration with the
      mean training loss of that iteration
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
    - batch_size: Batch size (Around 10 for 8GB GPU)
    - accumulation_steps: Number of mini-batches whose gradients are accumulated
      before each optimizer update; each loss is divided by this value
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epoch: Number of epochs to do
    - max_acc: Accuracy in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability for removing the minimap (put a black square)
       from a training example (0<=hide_map_prob<=1)
    - dropout_images_prob List of 5 floats or None, probability for removing each input image during training
     (black image) from a training example (0<=dropout_images_prob<=1)
    - num_load_files_training: Number of training files loaded per iteration
    - fp16: Use FP16 for training
    - amp_opt_level: If FP16 training Nvidia apex opt level
    - save_checkpoints: save a checkpoint every save_every iterations
      (Each checkpoint will rewrite the previous one)
    - eval_every: Evaluate on train/dev/test every eval_every iterations
    - save_every: Checkpoint interval, measured in iterations
    - save_best: save the model that achieves the higher accuracy in the development set

    Output:
     - float: Accuracy in the development test of the best model
    """
    writer: SummaryWriter = SummaryWriter()

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss()
    print("Loading dev set")
    X_dev, y_dev = load_dataset(dev_dir, fp=16 if fp16 else 32)
    X_dev = torch.from_numpy(X_dev)
    print("Loading test set")
    X_test, y_test = load_dataset(test_dir, fp=16 if fp16 else 32)
    X_test = torch.from_numpy(X_test)
    total_training_examples: int = 0
    model.zero_grad()

    printTrace("Training...")
    for epoch in range(num_epoch):
        step_no: int = 0
        iteration_no: int = 0
        num_used_files: int = 0
        data_loader = DataLoader_AutoDrive(
            dataset_dir=train_dir,
            nfiles2load=num_load_files_training,
            hide_map_prob=hide_map_prob,
            dropout_images_prob=dropout_images_prob,
            fp=16 if fp16 else 32,
        )

        data = data_loader.get_next()
        # Get files in batches, all files will be loaded and data will be shuffled
        while data:
            X, y = data
            model.train()
            start_time: float = time.time()
            total_training_examples += len(y)
            running_loss: float = 0.0
            num_batches: int = 0
            # NOTE(review): acc_dev is reset on every iteration, so a checkpoint
            # written on an iteration without evaluation stores 0.0 -- confirm
            # whether max_acc should be stored instead.
            acc_dev: float = 0.0

            for X_batch, y_batch in nn_batchs(X, y, batch_size):
                X_batch, y_batch = (
                    torch.from_numpy(X_batch).to(device),
                    torch.from_numpy(y_batch).long().to(device),
                )

                outputs = model.forward(X_batch)
                # Divide so the accumulated gradient is an average over the
                # accumulation window.
                loss = criterion(outputs, y_batch) / accumulation_steps
                running_loss += loss.item()

                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # Update the weights at the end of every accumulation window
                # (the remainder must be exactly 0), and also on the last
                # batch of the epoch so trailing gradients are not discarded.
                if (step_no + 1) % accumulation_steps == 0 or (
                        num_used_files + 1 >
                        len(data_loader) - num_load_files_training
                        and num_batches == math.ceil(len(y) / batch_size) - 1
                ):
                    # Clip once, on the fully accumulated gradient, right
                    # before the update.
                    if fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), 1.0)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    model.zero_grad()

                num_batches += 1
                step_no += 1

            num_used_files += num_load_files_training

            # Print Statistics
            printTrace(
                f"EPOCH: {initial_epoch+epoch}. Iteration {iteration_no}. "
                f"{num_used_files} of {len(data_loader)} files. "
                f"Total examples used for training {total_training_examples}. "
                f"Iteration time: {round(time.time() - start_time,2)} secs.")
            printTrace(
                f"Loss: {-1 if num_batches == 0 else running_loss / num_batches}. "
                f"Learning rate {optimizer.state_dict()['param_groups'][0]['lr']}"
            )

            # A file group may yield no batches; skip the loss-driven updates
            # instead of dividing by zero (the log line above already guards it).
            if num_batches > 0:
                mean_loss: float = running_loss / num_batches
                writer.add_scalar("Loss/train", mean_loss, iteration_no)
                scheduler.step(mean_loss)

            if (iteration_no + 1) % eval_every == 0:
                start_time_eval: float = time.time()
                if len(X) > 0 and len(y) > 0:
                    acc_train: float = evaluate(
                        model=model,
                        X=torch.from_numpy(X),
                        golds=y,
                        device=device,
                        batch_size=batch_size,
                    )
                else:
                    # Sentinel: nothing to evaluate on for this file group.
                    acc_train = -1.0

                acc_dev: float = evaluate(
                    model=model,
                    X=X_dev,
                    golds=y_dev,
                    device=device,
                    batch_size=batch_size,
                )

                acc_test: float = evaluate(
                    model=model,
                    X=X_test,
                    golds=y_test,
                    device=device,
                    batch_size=batch_size,
                )

                printTrace(
                    f"Acc training set: {round(acc_train,2)}. "
                    f"Acc dev set: {round(acc_dev,2)}. "
                    f"Acc test set: {round(acc_test,2)}.  "
                    f"Eval time: {round(time.time() - start_time_eval,2)} secs."
                )

                # Chained comparison: acc_dev must be positive AND beat the best so far.
                if 0.0 < acc_dev > max_acc and save_best:
                    max_acc = acc_dev
                    printTrace(
                        f"New max acc in dev set {round(max_acc,2)}. Saving model..."
                    )
                    save_model(
                        model=model,
                        save_dir=output_dir,
                        fp16=fp16,
                        amp_opt_level=amp_opt_level,
                    )
                if acc_train > -1:
                    writer.add_scalar("Accuracy/train", acc_train,
                                      iteration_no)
                writer.add_scalar("Accuracy/dev", acc_dev, iteration_no)
                writer.add_scalar("Accuracy/test", acc_test, iteration_no)

            if save_checkpoints and (iteration_no + 1) % save_every == 0:
                printTrace("Saving checkpoint...")
                save_checkpoint(
                    path=os.path.join(output_dir, "checkpoint.pt"),
                    model=model,
                    optimizer_name=optimizer_name,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    acc_dev=acc_dev,
                    epoch=initial_epoch + epoch,
                    fp16=fp16,
                    opt_level=amp_opt_level,
                )

            iteration_no += 1
            data = data_loader.get_next()

        data_loader.close()

    return max_acc
Exemplo n.º 16
0
def model_train \
    ( data_trn:torch.utils.data.Dataset
    , modl:torch.nn.Module
    , crit:torch.nn
    , optm:torch.optim
    , batch_size:int=100
    , hidden_shapes:list=[20,30,40]
    , hidden_acti:str="relu"
    , final_shape:int=1
    , final_acti:str="sigmoid"
    , device:torch.device=get_device()
    , scheduler:torch.optim.lr_scheduler=None
    ):
    """
    Run one training pass over `data_trn` and return the averaged metrics.

    Parameters
    ----------
    data_trn : dataset yielding `(inputs, labels)` pairs.
    modl : model; called with `feat=` plus the four architecture arguments below.
    crit : loss callable, invoked as `crit(output, labels.unsqueeze(1))`.
    optm : optimizer; zeroed and stepped once per batch.
    batch_size : batch size of the shuffled DataLoader built here.
    hidden_shapes, hidden_acti, final_shape, final_acti : forwarded unchanged
        to the model call. `hidden_shapes` is never mutated here.
    device : device every batch is moved to before the forward pass.
        NOTE(review): the model itself is not moved here; presumably the
        caller already placed it on `device`.
    scheduler : optional LR scheduler, stepped once after all batches.

    Returns
    -------
    (loss, accuracy) : the sum of per-batch `loss.item()` values and the count
        of `output.argmax(1) == labels` matches, each divided by
        `len(data_trn)`.
    """

    # Set to train
    modl.train()
    loss_trn = 0.0
    accu_trn = 0.0

    # Set data generator
    load_trn = DataLoader(data_trn, batch_size=batch_size, shuffle=True, num_workers=0)

    # Loop over each batch
    for inputs, labels in load_trn:

        # Push data to device. Tensor.to() returns a new tensor rather than
        # moving in place, so the result has to be re-bound.
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero out the parameter gradients
        optm.zero_grad()

        # Feed forward
        output = modl \
            ( feat=inputs
            , hidden_shapes=hidden_shapes
            , hidden_acti=hidden_acti
            , final_shape=final_shape
            , final_acti=final_acti
            )

        # Calc loss
        loss = crit(output, labels.unsqueeze(1))

        # Global metrics
        loss_trn += loss.item()
        accu_trn += (output.argmax(1) == labels).sum().item()

        # Feed backward
        loss.backward()

        # Optimise
        optm.step()

    # Adjust scheduler
    if scheduler is not None:
        scheduler.step()

    return loss_trn/len(data_trn), accu_trn/len(data_trn)
Exemplo n.º 17
0
def _write_losses(file_losses, epoch, avg_train_error, avg_validation_error):
    """Overwrite `file_losses` with the epoch index and both loss histories."""
    with open(file_losses, "w") as f:
        f.write("Epoch {}".format(epoch))
        f.write(str(avg_train_error))
        f.write(str(avg_validation_error))


def train(
    model: torch.nn.Module,
    criterion: torch.nn.modules.loss._Loss,
    dataloader_train: torch.utils.data.DataLoader,
    dataloader_validation: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    use_scheduler: bool,
    scheduler: torch.optim.lr_scheduler,
    num_epochs: int,
    device,
    file_losses: str,
    saving_frequency: int,
):
    """
    Train `model` for `num_epochs` epochs, validating after each one.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train.
    criterion : torch.nn.modules.loss._Loss
        Criterion (Loss) to use during training.
    dataloader_train : torch.utils.data.DataLoader
        Dataloader for training.
    dataloader_validation : torch.utils.data.DataLoader
        Dataloader to validate the model during training after each epoch.
    optimizer : torch.optim.Optimizer
        Optimizer used for training.
    use_scheduler : bool
        If True, `scheduler.step()` is called once at the end of every epoch.
    scheduler : torch.optim.lr_scheduler.MultiStepLR
        Scheduler to use to adapt learning rate during training.
    num_epochs : int
        Number of epochs to train for.
    device :
        Device on which to train (GPU or CPU cuda devices)
    file_losses : str
        Name of the file in which to save the Train and Test losses.
    saving_frequency : int
        Frequency (in epochs) at which to save Train and Test loss on file.

    Returns
    -------
    avg_train_error, avg_validation_error : list of float, list of float
        List of Train errors or losses after each epoch.
        List of Validation errors or losses after each epoch.

    """

    print("Starting training during {} epochs".format(num_epochs))
    avg_train_error = []
    avg_validation_error = []

    for epoch in range(num_epochs):

        # Writing results to file regularly in case of interruption during training.
        # The parentheses matter: without them `1 % saving_frequency` binds
        # first and the condition is almost never true.
        if (epoch + 1) % saving_frequency == 0:
            _write_losses(file_losses, epoch, avg_train_error,
                          avg_validation_error)

        model.train()
        train_error = []
        for batch_x, batch_y in dataloader_train:
            batch_x = batch_x.to(device, dtype=torch.float32)
            batch_y = batch_y.to(device, dtype=torch.float32)

            # Evaluate the network (forward pass)
            model.zero_grad()
            output = model(batch_x)

            # output is Bx1xHxW and batch_y is BxHxW, squeezing first dimension of output to have same dimension
            loss = criterion(torch.squeeze(output, 1), batch_y)
            # Store a plain float: keeping the tensor would keep every batch's
            # autograd graph alive until the end of the epoch.
            train_error.append(loss.item())

            # Compute the gradient
            loss.backward()

            # Update the parameters of the model with a gradient step
            optimizer.step()

        # Each scheduler step is done after a whole epoch
        # Once milestones epochs are reached the learning rates is decreased.
        if use_scheduler:
            scheduler.step()

        # Test the quality on the whole training set (overestimating the true value)
        avg_train_error.append(sum(train_error) / len(train_error))

        # Validate the quality on the validation set
        model.eval()
        validation_error = []
        with torch.no_grad():
            for batch_x_validation, batch_y_validation in dataloader_validation:
                batch_x_validation = batch_x_validation.to(
                    device, dtype=torch.float32)
                batch_y_validation = batch_y_validation.to(
                    device, dtype=torch.float32)
                # Evaluate the network (forward pass)
                prediction = model(batch_x_validation)
                validation_error.append(
                    criterion(torch.squeeze(prediction, 1),
                              batch_y_validation).item())
            avg_validation_error.append(
                sum(validation_error) / len(validation_error))

        print(
            "Epoch {} | Train Error: {:.5f}, Validation Error: {:.5f}".format(
                epoch, avg_train_error[-1], avg_validation_error[-1]))

    # Writing final results on the file
    _write_losses(file_losses, epoch, avg_train_error, avg_validation_error)

    return avg_train_error, avg_validation_error
Exemplo n.º 18
0
def train_person_segmentor(
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        valid_loader: torch.utils.data.DataLoader,
        criterion: callable,
        optimiser: torch.optim.Optimizer,
        *,
        save_model_path: Path,
        learning_rate: Number = 6e-2,
        scheduler: torch.optim.lr_scheduler = None,
        n_epochs: int = 100,
        writer: ImageWriterMixin = MockWriter(),
):
    """
    Train `model` for `n_epochs` epochs, validating after each one and saving
    the weights whenever the validation loss does not get worse.

    :param model: network to train; only the first element of its return value
        is used as the prediction
    :type model: torch.nn.Module
    :param train_loader: yields ``(data, target)`` training batches
    :type train_loader: torch.utils.data.DataLoader
    :param valid_loader: yields ``(data, target)`` validation batches
    :type valid_loader: torch.utils.data.DataLoader
    :param criterion: loss, called as ``criterion(output, target)`` with the
        target cast to float
    :type criterion: callable
    :param optimiser: optimiser, zeroed and stepped once per training batch
    :type optimiser: torch.optim.Optimizer
    :param save_model_path: destination of ``model.state_dict()`` each time the
        validation loss is less than or equal to the best seen so far
    :type save_model_path: Path
    :param learning_rate: forwarded to ``reschedule_learning_rate`` as
        ``starting_learning_rate``
    :type learning_rate: Number
    :param scheduler: optional scheduler; when given it is stepped once per
        epoch and then handed to ``reschedule_learning_rate``, whose returned
        optimiser / scheduler pair replaces the current one
    :type scheduler: torch.optim.lr_scheduler
    :param n_epochs: number of epochs, must be positive
    :type n_epochs: int
    :param writer: receives the loss / learning-rate scalars and the images of
        the last validation batch of every epoch
    :type writer: ImageWriterMixin
    :return: the trained model (same object that was passed in)
    :rtype: torch.nn.Module"""
    # numpy.inf is the spelling that exists in every NumPy release; the
    # numpy.Inf alias was removed in NumPy 2.0.
    valid_loss_min = numpy.inf  # track change in validation loss
    assert n_epochs > 0, n_epochs
    E = tqdm(range(1, n_epochs + 1))
    for epoch_i in E:
        train_loss = 0.0
        valid_loss = 0.0

        with TorchTrainSession(model):
            for data, target in tqdm(train_loader):
                device = global_torch_device()
                output, *_ = model(data.to(device))
                loss = criterion(output, target.to(device).float())

                optimiser.zero_grad()
                loss.backward()
                optimiser.step()

                # Weight the batch loss by batch size so the epoch average is
                # per sample.
                train_loss += loss.cpu().item() * data.size(0)

        with TorchEvalSession(model):
            with torch.no_grad():
                for data, target in tqdm(valid_loader):
                    device = global_torch_device()
                    target = target.float()
                    # Move the target once and reuse it for both metrics;
                    # `target` itself is left untouched for the image writer.
                    target_on_device = target.to(device)
                    (
                        output,
                        *_,
                    ) = model(  # forward pass: compute predicted outputs by passing inputs to the model
                        data.to(device))
                    validation_loss = criterion(
                        output, target_on_device)  # calculate the batch loss
                    writer.scalar(
                        "dice_validation",
                        dice_loss(output, target_on_device),
                    )

                    valid_loss += validation_loss.detach().cpu().item(
                    ) * data.size(0)  # update average validation loss
                writer.image("input", data, epoch_i)  # write the last batch
                writer.image("truth", target, epoch_i)  # write the last batch
                writer.image("prediction", torch.sigmoid(output),
                             epoch_i)  # write the last batch

        # calculate average losses
        train_loss = train_loss / len(train_loader.dataset)
        valid_loss = valid_loss / len(valid_loader.dataset)

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print(
                f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}).  Saving model ..."
            )
            torch.save(model.state_dict(), save_model_path)
            valid_loss_min = valid_loss

        if scheduler:
            scheduler.step()
            # The helper may hand back a different optimiser / scheduler pair.
            optimiser, scheduler = reschedule_learning_rate(
                model,
                optimiser,
                epoch_i,
                scheduler,
                starting_learning_rate=learning_rate,
            )

        # print training/validation statistics
        current_lr = next(iter(optimiser.param_groups))["lr"]
        E.set_description(f"Epoch: {epoch_i} "
                          f"Training Loss: {train_loss:.6f} "
                          f"Validation Loss: {valid_loss:.6f} "
                          f"Learning rate: {current_lr:.6f}")
        writer.scalar("training_loss", train_loss)
        writer.scalar("validation_loss", valid_loss)
        writer.scalar("learning_rate", current_lr)

    return model
Exemplo n.º 19
0
def _clip_and_step(model, optimizer, amp=None, max_grad_norm: float = 1.0):
    """Clip the accumulated gradients, apply one optimizer update and reset the gradients."""
    if amp is not None:
        # Under apex AMP the gradients to clip are the ones of the master params.
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                       max_grad_norm)
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()
    model.zero_grad()


def _run_training_epoch(model, optimizer, criterion, train_loader,
                        accumulation_steps: int, target_device,
                        amp=None) -> float:
    """
    Run one pass over train_loader, updating the model every accumulation_steps batches

    Output:
     - float: Sum over all the batches of (batch loss / accumulation_steps)
    """
    model.train()
    running_loss: float = 0.0
    pending_gradients: bool = False

    for batch_index, inputs in enumerate(train_loader):
        num_examples = inputs[0].shape[0]
        # Presumably every example is a sequence of 5 RGB images that the model
        # receives flattened along the batch dimension -- TODO confirm.
        X_batch = torch.reshape(
            inputs[0],
            (num_examples * 5, 3, inputs[0].shape[2], inputs[0].shape[3]),
        ).to(target_device)
        y_batch = torch.reshape(inputs[1],
                                (num_examples, )).long().to(target_device)

        outputs = model.forward(X_batch)
        # The loss is scaled so that the gradients summed over
        # accumulation_steps batches are the average of the group.
        loss = criterion(outputs, y_batch) / accumulation_steps
        running_loss += loss.item()

        if amp is not None:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        pending_gradients = True

        if (batch_index + 1) % accumulation_steps == 0:
            _clip_and_step(model, optimizer, amp)
            pending_gradients = False

    if pending_gradients:
        # The number of batches is not a multiple of accumulation_steps:
        # apply the gradients of the last (partial) group instead of dropping them.
        _clip_and_step(model, optimizer, amp)

    return running_loss


def train(
    model: TEDD1104,
    optimizer_name: str,
    optimizer: torch.optim,
    scheduler: torch.optim.lr_scheduler,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    accumulation_steps: int,
    initial_epoch: int,
    num_epoch: int,
    max_acc: float,
    hide_map_prob: float,
    dropout_images_prob: List[float],
    num_load_files_training: int,
    fp16: bool = True,
    amp_opt_level=None,
    save_checkpoints: bool = True,
    eval_every: int = 5,
    save_every: int = 20,
    save_best: bool = True,
):
    """
    Train a model

    Input:
    - model: TEDD1104 model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam] (only stored in the checkpoints)
    - optimizer: Optimizer (torch.optim)
    - scheduler: Learning rate scheduler, stepped once per epoch with the training loss of
      the epoch (so it must accept a metric, presumably ReduceLROnPlateau -- TODO confirm)
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
    - batch_size: Batch size (Around 10 for 8GB GPU)
    - accumulation_steps: Number of batches whose gradients are accumulated before each
      optimizer update (must be >= 1)
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epoch: Number of epochs to do
    - max_acc: Accuracy in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability for removing the minimap (put a black square)
       from a training example (0<=hide_map_prob<=1). Not used by this function
    - dropout_images_prob List of 5 floats or None, probability for removing each input image during training
     (black image) from a training example (0<=dropout_images_prob<=1). Not used by this function
    - num_load_files_training: Not used by this function
    - fp16: Use FP16 for training
    - amp_opt_level: If FP16 training Nvidia apex opt level
    - save_checkpoints: save a checkpoint every save_every epochs (Each checkpoint will rewrite the previous one)
    - eval_every: Evaluate in the development and test sets every eval_every epochs
    - save_every: Number of epochs between checkpoints
    - save_best: save the model that achieves the higher accuracy in the development set

    Output:
     - float: Accuracy in the development test of the best model

    Raises:
     - ValueError: If accumulation_steps is lower than 1
     - ImportError: If fp16 is requested and apex is not installed
    """
    if accumulation_steps < 1:
        raise ValueError(
            f"accumulation_steps must be >= 1, got {accumulation_steps}")

    amp = None
    if fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err

    writer: SummaryWriter = SummaryWriter()
    try:
        criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss()
        print("Loading dev set")
        X_dev, y_dev = load_dataset(dev_dir, fp=16 if fp16 else 32)
        X_dev = torch.from_numpy(X_dev)
        print("Loading test set")
        X_test, y_test = load_dataset(test_dir, fp=16 if fp16 else 32)
        X_test = torch.from_numpy(X_test)
        model.zero_grad()

        train_loader = DataLoader(dataset=PickleDataset(train_dir),
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=8)

        printTrace("Training...")
        for epoch in range(num_epoch):
            # Stays at 0.0 (and is stored as such in the checkpoint) on the
            # epochs in which no evaluation is run.
            acc_dev: float = 0.0

            running_loss: float = _run_training_epoch(
                model=model,
                optimizer=optimizer,
                criterion=criterion,
                train_loader=train_loader,
                accumulation_steps=accumulation_steps,
                target_device=device,
                amp=amp,
            )

            scheduler.step(running_loss)

            # Print Statistics
            printTrace(
                f"Loss: {running_loss}. "
                f"Learning rate {optimizer.state_dict()['param_groups'][0]['lr']}"
            )

            writer.add_scalar("Loss/train", running_loss, epoch)

            if (epoch + 1) % eval_every == 0:
                start_time_eval: float = time.time()

                acc_dev = evaluate(
                    model=model,
                    X=X_dev,
                    golds=y_dev,
                    device=device,
                    batch_size=batch_size,
                )

                acc_test: float = evaluate(
                    model=model,
                    X=X_test,
                    golds=y_test,
                    device=device,
                    batch_size=batch_size,
                )

                printTrace(
                    f"Acc dev set: {round(acc_dev,2)}. "
                    f"Acc test set: {round(acc_test,2)}.  "
                    f"Eval time: {round(time.time() - start_time_eval,2)} secs."
                )

                # Only a strictly positive accuracy that beats the best one
                # seen so far triggers a save.
                if save_best and acc_dev > 0.0 and acc_dev > max_acc:
                    max_acc = acc_dev
                    printTrace(
                        f"New max acc in dev set {round(max_acc,2)}. Saving model..."
                    )
                    save_model(
                        model=model,
                        save_dir=output_dir,
                        fp16=fp16,
                        amp_opt_level=amp_opt_level,
                    )
                writer.add_scalar("Accuracy/dev", acc_dev, epoch)
                writer.add_scalar("Accuracy/test", acc_test, epoch)

            if save_checkpoints and (epoch + 1) % save_every == 0:
                printTrace("Saving checkpoint...")
                save_checkpoint(
                    path=os.path.join(output_dir, "checkpoint.pt"),
                    model=model,
                    optimizer_name=optimizer_name,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    acc_dev=acc_dev,
                    epoch=initial_epoch + epoch,
                    fp16=fp16,
                    opt_level=amp_opt_level,
                )
    finally:
        # Flush the pending events and release the event file even if training fails.
        writer.close()

    return max_acc