Example #1
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, noise_sd: float):
    """
    Function to do one training epoch
        :param loader:DataLoader: dataloader (train)
        :param model:torch.nn.Module: the classifier being trained
        :param criterion: the loss function
        :param optimizer:Optimizer: the optimizer used during training
        :param epoch:int: the current epoch number (for logging)
        :param noise_sd:float: the std-dev of the Gaussian noise perturbation of the input
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # augment inputs with noise
        inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
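
A note on the augmentation step above: `inputs + torch.randn_like(inputs, device='cuda') * noise_sd` perturbs every example with i.i.d. Gaussian noise of standard deviation `noise_sd`. A minimal self-contained sketch of that step (CPU-only, made-up shapes):

import torch

noise_sd = 0.25                      # standard deviation of the Gaussian perturbation
inputs = torch.rand(8, 3, 32, 32)    # a fake batch of 8 RGB 32x32 images in [0, 1]

# Add i.i.d. Gaussian noise with the same shape/device/dtype as the inputs.
noisy_inputs = inputs + torch.randn_like(inputs) * noise_sd

print(noisy_inputs.shape)            # torch.Size([8, 3, 32, 32])
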
Example #2
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
        data must be a list of paths to the data chunks.
    :param val_smiles: Validation SMILES strings without targets.
    :param test_smiles: Test SMILES strings without targets, used for the adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_FEATURES, f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch \
            else len(data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(
                args)
            mol_batch = task_test_data
            smiles_batch = task_train_data.smiles()
            features_batch = task_train_data.features()
            target_batch = task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(
                loss, [p for p in model.parameters() if p.requires_grad])
            theta = [
                p for p in model.named_parameters() if p[1].requires_grad
            ]  # comes in same order as grad
            theta_prime = {
                p[0]: p[1] - args.maml_lr * grad[i]
                for i, p in enumerate(theta)
            }
            for name, nongrad_param in [
                    p for p in model.named_parameters()
                    if not p[1].requires_grad
            ]:
                theta_prime[name] = nongrad_param + torch.zeros(
                    nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop()
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = \
                mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                features_targets = torch.FloatTensor(target_batch['features']) \
                    if target_batch['features'] is not None else None  # num_molecules x features_size
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(
                        args.class_weights[task_num][targets[:, task_num].long()])
                class_weights = torch.stack(class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)
            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    loss += loss_func(preds[:, task, :], targets[:, task]) \
                        * class_weights[:, task] * mask[:, task]  # for some reason cross entropy doesn't support multi target
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                                / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch = task_test_data.smiles()
            features_batch = task_test_data.features()
            target_batch = [t[task_idx] for t in task_test_data.targets()]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            iter_count += len(smiles_batch)  # TODO check that this makes sense, but it's just for display
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for param_group in optimizer.param_groups:
                    param_group['weight_decay'] = max(
                        0, param_group['weight_decay'] - args.adjust_weight_decay_step)
            else:
                for param_group in optimizer.param_groups:
                    param_group['weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )
            if args.adversarial:
                debug(
                    f'D Loss = {d_loss_avg:.4e}, G Loss = {g_loss_avg:.4e}, GP Norm = {gp_norm_avg:.4}'
                )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy value so the subprocess knows we're done
        batch_process.join()

    return n_iter
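
The MAML branch above builds `theta_prime` by taking a single gradient step on the support loss without modifying the model's real parameters. A minimal sketch of that inner-loop update on a toy linear model; `inner_lr` stands in for the role of `args.maml_lr` and is not part of the project:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
inner_lr = 0.01                                   # illustrative inner-loop learning rate
x, y = torch.randn(16, 4), torch.randn(16, 1)

loss = nn.functional.mse_loss(model(x), y)
grads = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])

# One SGD step "on paper": adapted parameters keyed by name, in the same order as grads.
theta_prime = {
    name: p - inner_lr * g
    for (name, p), g in zip(
        ((n, p) for n, p in model.named_parameters() if p.requires_grad), grads)
}
print({k: v.shape for k, v in theta_prime.items()})
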
Example #3
File: trainer.py  Project: lzfelix/flare
def train_on_loader(model: nn.Module,
                    train_gen: DataLoader,
                    val_gen: Optional[DataLoader],
                    loss_fn: Any,
                    optimizer: Optimizer,
                    n_epochs: int,
                    batch_first: bool = False,
                    device: Optional[torch.device] = torch.device('cpu'),
                    callbacks: Optional[List[Callback]] = None,
                    before_step=None,
                    verbosity: int = 2) -> ModelHistory:
    """Trains a model using data from a DataLoader.

    # Arguments
        model: The PyTorch model.
        train_gen: A DataLoader containing the training data.
        val_gen: A DataLoader containing the validation data.
        loss_fn: The loss function from which gradients are computed.
            Its expected signature is `loss_fn(model_output, y_true)`.
        optimizer: The optimizer used in the backpropagation step.
        n_epochs: How many passes should be performed over the train_gen.
        batch_first: For sequential data, if True data is expected to have the layout
            `[batch_size, seq_len, *]`, otherwise `[seq_len, batch_size, *]`.
        device: The torch.device on which the batches are placed (defaults to CPU).
        callbacks: List of utility callbacks to help train the model.
        before_step: Optional callable invoked as `before_step(model, loss, optimizer)`
            right before the optimizer step.
        verbosity: 0: silent, 1: show epoch progress bar, 2: show batch progress bar.
    # Return
        A ModelHistory object representing the model training history.
    """

    callbacks_container = CallbacksContainer(callbacks or [])
    batch_index = 0 if batch_first else 1

    model_history = ModelHistory(model)

    epoch_iterator = range(1, n_epochs + 1)
    if verbosity == 1:
        epoch_iterator = tqdm.tqdm(epoch_iterator, desc='Epoch')
    elif verbosity == 2:
        callbacks_container.append(ProgressBar(len(train_gen), n_epochs))

    for epoch in epoch_iterator:
        model.train()
        callbacks_container.on_epoch_begin(epoch, model_history)

        epoch_loss = 0
        seen_samples = 0
        training_metrics = defaultdict(int)

        for batch_id, batch_data in enumerate(train_gen):
            callbacks_container.on_batch_begin(batch_id, model_history)

            # even if batch_data = [x, y], batch_features = [x] and batch_labels = y
            batch_features: list = batch_data[:-1]
            batch_labels = batch_data[-1]

            batch_features = [
                _move_to_device(ft, device) for ft in batch_features
            ]
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            output = model(*batch_features)
            loss = loss_fn(output, batch_labels)
            loss.backward()

            if before_step:
                before_step(model, loss, optimizer)

            optimizer.step()

            # All feature matrices should have the same amount of sample entries,
            # hence we can take any of them to figure out the batch size
            n_samples = batch_features[0].size(batch_index)

            seen_samples += n_samples
            epoch_loss += loss.item()

            # Accumulating metrics and losses for the current epoch
            batch_metrics = model.metric(output, batch_labels)
            for m_name, m_value in batch_metrics.items():
                training_metrics[m_name] += m_value
            training_metrics['loss'] = epoch_loss / (batch_id + 1)

            # Normalizing metrics up to the current batch to display in the progress bar
            model_history.append_batch_data(
                _normalize_metrics(training_metrics, seen_samples))

            callbacks_container.on_batch_end(batch_id, model_history)

        model_history.append_trn_logs(
            _normalize_metrics(training_metrics, seen_samples))

        if val_gen:
            val_logs = evaluate_on_loader(model,
                                          val_gen,
                                          loss_fn,
                                          batch_first,
                                          device,
                                          verbosity=0)

            # Adding the val_ prefix and storing metrics over the entire validation data
            val_logs = {
                'val_' + m_name: m_value
                for m_name, m_value in val_logs.items()
            }
            model_history.append_dev_logs(val_logs)

        callbacks_container.on_epoch_end(epoch, model_history)
        if model_history.should_stop_training():
            break

    model_history.close(n_epochs)
    callbacks_container.on_train_end()

    return model_history
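
`before_step` is invoked as `before_step(model, loss, optimizer)` after `loss.backward()` and before `optimizer.step()`, which makes it a natural place for gradient post-processing. A hypothetical hook that clips the gradient norm (the threshold 1.0 is an arbitrary choice, not something the library prescribes):

import torch.nn as nn

def clip_before_step(model, loss, optimizer):
    # Runs between loss.backward() and optimizer.step(): gradients already exist here.
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Would be passed as: train_on_loader(..., before_step=clip_before_step)
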
def train_stego(*, stegoanalyser: nn.Module,
                train_iterator: DataBatchIterator,
                val_iterator: DataBatchIterator,
                text_iterator: Iterator,
                n_epoch: int, stegoanalyser_opt: Optimizer,
                callbacks: Sequence[Callable] = None, logger: TBLogger,
                encoder: SigmoidTorchEncoder):
    criterion = F.binary_cross_entropy_with_logits
    callbacks = callbacks or []

    for epoch in tqdm(range(n_epoch)):
        stegoanalyser_losses = []
        with train_iterator as iterator:
            for real_batch, _ in iterator:
                batch_size = len(real_batch)
                labels = np.random.choice([0, 1], (batch_size, 1, 1, 1))
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)
                labels = torch.from_numpy(labels).float()
                # train stegoanalyzer
                stegoanalyser_opt.zero_grad()
                stegoanalyser_losses.append(
                    process_batch(encoded_images.detach(), labels, stegoanalyser, criterion))
                stegoanalyser_opt.step()

        with val_iterator as iterator:
            accuracy = []
            for real_batch, _ in iterator:
                batch_size = len(real_batch)

                labels = np.random.choice([0, 1], batch_size)
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)
                # evaluate stegoanalyzer
                out = inference_step(encoded_images, stegoanalyser).cpu().detach()
                out = torch.sigmoid(out) > 0.5
                out = out.reshape(len(encoded_images)).numpy()
                accuracy_score = sklearn.metrics.accuracy_score(labels, out)
                accuracy.append(accuracy_score)

            mean_accuracy = np.mean(accuracy)
            print(f'validation accuracy score {mean_accuracy}')

            losses = {'Stegoanalyser loss': np.mean(stegoanalyser_losses),
                      'Val accuracy': mean_accuracy}
            logger.policies(losses, epoch)

            # run callbacks
            for callback in callbacks:
                callback(epoch)
Example #5
def torch_single_train(model: PyTorchForecast,
                       opt: optim.Optimizer,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       data_loader: DataLoader,
                       takes_target: bool,
                       meta_data_model: PyTorchForecast,
                       meta_data_model_representation: torch.Tensor,
                       meta_loss=None,
                       multi_targets=1,
                       forward_params: Dict = {}) -> float:
    probabilistic = None
    if "probabilistic" in model.params["model_params"]:
        probabilistic = True
    print('running torch_single_train')
    i = 0
    output_std = None
    running_loss = 0.0
    for src, trg in data_loader:
        opt.zero_grad()
        # Convert to CPU/GPU/TPU
        src = src.to(model.device)
        trg = trg.to(model.device)
        if meta_data_model:
            representation = meta_data_model.model.generate_representation(
                meta_data_model_representation)
            forward_params["meta_data"] = representation
            if meta_loss:
                output = meta_data_model.model(meta_data_model_representation)
                met_loss = compute_loss(meta_data_model_representation, output,
                                        torch.rand(2, 3, 2), meta_loss, None)
                met_loss.backward()
        if takes_target:
            forward_params["t"] = trg
        output = model.model(src, **forward_params)
        if multi_targets == 1:
            labels = trg[:, :, 0]
        elif multi_targets > 1:
            labels = trg[:, :, 0:multi_targets]
        if probabilistic:
            output_dist = output
            output = output_dist.mean
            output_std = output_dist.stddev
        loss = compute_loss(labels,
                            output,
                            src,
                            criterion,
                            None,
                            probabilistic,
                            output_std,
                            m=multi_targets)
        if loss > 100:
            print("Warning: high loss detected")
        loss.backward()
        opt.step()
        if torch.isnan(loss) or loss == float('inf'):
            raise ValueError(
                "Error infinite or NaN loss detected. Try normalizing data or performing interpolation"
            )
        running_loss += loss.item()
        i += 1
    print("The running loss is: ")
    print(running_loss)
    print("The number of items in train is: " + str(i))
    total_loss = running_loss / float(i)
    return total_loss
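
When the model is probabilistic, the loop above treats the raw output as a distribution object and reads its `.mean` and `.stddev` before computing the loss. A small sketch of that unpacking with a `torch.distributions.Normal` stand-in (the concrete distribution type is an assumption; the source only shows that `.mean` and `.stddev` are accessed):

import torch
from torch.distributions import Normal

# Stand-in for a probabilistic forecasting head: a batch of per-step Normal distributions.
output = Normal(loc=torch.zeros(8, 24), scale=torch.ones(8, 24))

output_mean = output.mean      # used as the point prediction for the loss
output_std = output.stddev     # carried along for probabilistic loss terms

print(output_mean.shape, output_std.shape)
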
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          noise_sd: float,
          attacker: Attacker,
          device: torch.device,
          writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_reg = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = _chunk_minibatch(batch, args.num_noise_vec)
        for inputs, targets in mini_batches:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            noises = [
                torch.randn_like(inputs, device=device) * noise_sd
                for _ in range(args.num_noise_vec)
            ]

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model, inputs, targets, noises=noises)
                model.train()
                requires_grad_(model, True)

            # augment inputs with noise
            inputs_c = torch.cat([inputs + noise for noise in noises], dim=0)
            targets_c = targets.repeat(args.num_noise_vec)

            logits = model(inputs_c)
            loss_xent = criterion(logits, targets_c)

            logits_chunk = torch.chunk(logits, args.num_noise_vec, dim=0)
            loss_con = consistency_loss(logits_chunk, args.lbd, args.eta)

            loss = loss_xent + loss_con

            acc1, acc5 = accuracy(logits, targets_c, topk=(1, 5))
            losses.update(loss_xent.item(), batch_size)
            losses_reg.update(loss_con.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(epoch,
                                                i,
                                                len(loader),
                                                batch_time=batch_time,
                                                data_time=data_time,
                                                loss=losses,
                                                top1=top1,
                                                top5=top5))

    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/consistency', losses_reg.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

    return (losses.avg, top1.avg)
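
`consistency_loss` above is project-specific; the general pattern is to split the logits back into per-noise-sample chunks and penalize disagreement between them. A rough, non-authoritative sketch of one such regularizer, the KL divergence between the averaged prediction and each chunk's prediction, with `lbd` standing in for the role of `args.lbd`:

import torch
import torch.nn.functional as F

def consistency_penalty(logits_chunks, lbd=1.0):
    # logits_chunks: sequence of (batch, num_classes) logits, one per noise sample.
    probs = [F.softmax(l, dim=1) for l in logits_chunks]
    mean_prob = torch.stack(probs).mean(dim=0).clamp_min(1e-8)
    # KL(mean || chunk), averaged over the chunks.
    kl = sum(F.kl_div(p.clamp_min(1e-8).log(), mean_prob, reduction='batchmean') for p in probs)
    return lbd * kl / len(probs)

logits = torch.randn(3 * 8, 10)                       # 3 noise samples, batch of 8, 10 classes
penalty = consistency_penalty(torch.chunk(logits, 3, dim=0))
print(penalty.item())
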
def train_model(
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:

    device = model_utils.get_device()
    # loss_fn = nn.functional.binary_cross_entropy
    loss_fn = model_utils.l1_norm_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = torch.tensor(float('inf'))
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')
    scalar_rand = torch.distributions.uniform.Uniform(0.5, 1.5)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(train_dl):
                # trump_scale = scalar_rand.sample()
                # biden_scale = scalar_rand.sample()
                # y_batch_biden = y_batch_biden * biden_scale
                # y_batch_trump = y_batch_trump * trump_scale
                # x_batch = (y_batch_trump + y_batch_biden).abs().to(device)
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred_b, y_pred_t = model(x_batch)
                if args.train_trump:
                    # loss = loss_fn(y_pred_t * x_batch, y_batch_trump)
                    loss = loss_fn(y_pred_t, y_batch_trump)
                else:
                    # loss = loss_fn(y_pred_b * x_batch, y_batch_biden)
                    loss = loss_fn(y_pred_b, y_batch_biden)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar("train/Loss", loss,
                                  ((e - 1) * len(train_dl) + i) *
                                  args.train_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss

        # Validation portion
        torch.cuda.empty_cache()
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(dev_dl):
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                y_pred_b, y_pred_t = model(x_batch)
                # y_pred_b_mask = torch.ones_like(y_pred_b) * (y_pred_b > args.alpha)
                # y_pred_t_mask = torch.ones_like(y_pred_t) * (y_pred_t > args.alpha)
                y_pred_b_mask = torch.clamp(y_pred_b / x_batch, 0, 1)
                y_pred_t_mask = torch.clamp(y_pred_t / x_batch, 0, 1)

                loss_trump = val_loss_fn(y_pred_t_mask * x_batch,
                                         y_batch_trump)
                loss_biden = val_loss_fn(y_pred_b_mask * x_batch,
                                         y_batch_biden)

                if args.train_trump:
                    val_loss += loss_trump.item()
                else:
                    val_loss += loss_biden.item()
                num_batches_processed += 1

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(val_loss=val_loss /
                                         num_batches_processed)
                writer.add_scalar("Val/Biden Loss", loss_biden,
                                  ((e - 1) * len(dev_dl) + i) *
                                  args.val_batch_size)
                writer.add_scalar("Val/Trump Loss", loss_trump,
                                  ((e - 1) * len(dev_dl) + i) *
                                  args.val_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss_trump
                del loss_biden

            # Save model if it's the best one yet.
            if val_loss / num_batches_processed < best_val_loss:
                best_val_loss = val_loss / num_batches_processed
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model saved!')
                print(f'Best validation loss yet: {best_val_loss}')
            # Save model on checkpoints.
            if e % args.checkpoint_freq == 0:
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model checkpoint reached!')
                saved_checkpoints.append(filename)
                # Delete checkpoints if there are too many
                while len(saved_checkpoints) > args.num_checkpoints:
                    os.remove(saved_checkpoints.pop(0))

    return model
Example #8
def train(
    model: Union[nn.Module, nn.DataParallel],
    train_loader: DataLoader,
    metrics: Dict[str, Metric],
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    device: torch.device,
    epoch: int,
    log_interval: int,
    hooks: Optional[Sequence[Hook]] = None,
    teacher: Optional[Union[nn.Module, nn.DataParallel]] = None,
) -> Dict[str, float]:
    """
    Train a model on some data using some criterion and with some optimizer.

    Args:
        model: Model to train
        train_loader: Data loader for loading training data
        metrics: A dict mapping evaluation metric names to metrics classes
        optimizer: PyTorch optimizer
        scheduler: PyTorch scheduler
        device: PyTorch device object
        epoch: Current epoch, where the first epoch should start at 1
        log_interval: Number of batches before printing loss
        hooks: A sequence of functions that can implement custom behavior
        teacher: teacher network for knowledge distillation, if any

    Returns:
        A dictionary mapping evaluation metric names to computed values for the training set.
    """
    if hooks is None:
        hooks = []

    model.train()
    for metric in metrics.values():
        metric.reset()

    loss_fn = model.module.loss_fn if isinstance(model, nn.DataParallel) else model.loss_fn

    seen_examples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        if teacher is None:
            teacher_output = None
            loss = loss_fn(output, target)  # type: ignore
        else:
            teacher_output = teacher(data)
            loss = loss_fn(output, teacher_output, target)  # type: ignore
        loss.backward()
        optimizer.step()
        project(optimizer)
        scheduler.step()  # type: ignore

        with torch.no_grad():
            for metric in metrics.values():
                metric.update(output, target, teacher_output=teacher_output)

        for hook in hooks:
            hook(
                epoch=epoch,
                global_step=1 + (epoch - 1) * len(train_loader.dataset) +
                batch_idx,
                values_dict={'lr': _get_lr(optimizer)},
                log_interval=log_interval,
            )

        seen_examples += len(data)
        if batch_idx % log_interval == 0:
            logger.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tBatch Loss: {:.6f}'.format(
                    epoch,
                    seen_examples,
                    len(train_loader.dataset),
                    100 * batch_idx / len(train_loader),
                    loss.item(),
                ))

    # Computing evaluation metrics for training set
    computed_metrics = {
        name: metric.compute()
        for name, metric in metrics.items()
    }

    logger.info('Training set evaluation metrics:')
    for name, metric in metrics.items():
        logger.info(f'{name}: {metric}')

    return computed_metrics
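
When `teacher` is given, `loss_fn` is called as `loss_fn(output, teacher_output, target)`; its definition is not shown here. A typical distillation loss with that signature mixes a soft-target KL term against the teacher with ordinary cross-entropy (temperature `T` and weight `alpha` are illustrative assumptions):

import torch
import torch.nn.functional as F

def distillation_loss(output, teacher_output, target, T=4.0, alpha=0.9):
    # Soft targets: KL between temperature-softened student and teacher distributions.
    soft = F.kl_div(F.log_softmax(output / T, dim=1),
                    F.softmax(teacher_output / T, dim=1),
                    reduction='batchmean') * (T * T)
    # Hard targets: standard cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(output, target)
    return alpha * soft + (1 - alpha) * hard

student_logits, teacher_logits = torch.randn(8, 10), torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
print(distillation_loss(student_logits, teacher_logits, labels).item())
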
Example #9
def _apply_gradient_descent(optimizer: Optimizer) -> None:
    optimizer.step()
Example #10
def train(loader: DataLoader, model: nn.Module, criterion: Callable,
          optimizer: Optimizer, num_classes: int, num_super_classes: int,
          maf: torch.FloatTensor, epoch: int,
          args: ArgumentParser) -> torch.FloatTensor:
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('MLM Loss', ':.4e')
    accuracies = AverageMeter('Acc', ':.4e')
    accuracy_deltas = AverageMeter('Acc Delta', ':.4e')
    progress = ProgressMeter(len(loader),
                             [batch_time, losses, accuracies, accuracy_deltas],
                             prefix="Epoch: [{}]".format(epoch))

    model.train()

    device = get_device(args)
    end = time.time()
    for i, (genotypes, labels, super_labels) in enumerate(loader):

        ### Mask for Masked Language Modeling
        mask_num = torch.randint(1, genotypes.shape[1], (1, )).item()
        mask_scores = torch.rand(genotypes.shape[1])
        mask_indices = mask_scores.argsort(descending=True)[:mask_num]
        masked_genotypes = genotypes[:, mask_indices].reshape(-1)
        targets = (masked_genotypes == 1).float().clone().detach()
        genotypes[:, mask_indices] = 0
        maf_vector = maf[labels[0]]

        genotypes = genotypes.to(device)
        masked_genotypes = masked_genotypes.to(device)
        targets = targets.to(device)
        labels = labels.to(device)
        super_labels = super_labels.to(device)
        maf_vector = maf_vector.to(device)

        ### Train
        logits = model(genotypes, labels, super_labels)
        logits = logits[:, mask_indices].reshape(-1)

        # add weight to nonzero maf snps
        weights = torch.ones_like(logits)
        weight_coefficients = (maf_vector[mask_indices] > 0).repeat(
            genotypes.shape[0]).float() * (args.minor_coefficient - 1) + 1
        weights *= weight_coefficients

        loss = criterion(logits, targets, weight=weights, reduction='mean')
        model.zero_grad()
        loss.backward()
        optimizer.step()

        accuracy = (masked_genotypes * logits.sign()).mean() / 2 + .5
        baseline_accuracy = (
            masked_genotypes *
            (maf_vector[mask_indices].repeat(genotypes.shape[0]) -
             .5000001).sign()).mean() / 2 + .5
        accuracy_delta = accuracy - baseline_accuracy

        losses.update(loss.item(), genotypes.shape[0])
        accuracies.update(accuracy.item(), genotypes.shape[0])
        accuracy_deltas.update(accuracy_delta.item(), genotypes.shape[0])
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
    return losses.avg
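
The masking block at the top of the loop draws a random number of positions, selects them from a random score vector, keeps the original values as reconstruction targets, and zeroes them in the input. The same recipe on a toy tensor (shapes are made up):

import torch

genotypes = torch.randint(0, 2, (4, 10)).float()          # 4 samples, 10 positions

mask_num = torch.randint(1, genotypes.shape[1], (1,)).item()
mask_scores = torch.rand(genotypes.shape[1])
mask_indices = mask_scores.argsort(descending=True)[:mask_num]  # random subset of columns

masked_values = genotypes[:, mask_indices].reshape(-1)     # what the model must reconstruct
genotypes[:, mask_indices] = 0                             # hide the masked positions

print(mask_num, mask_indices.tolist(), masked_values.shape)
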
Example #11
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          noise_sd: float,
          attacker: Attacker = None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = get_minibatches(batch, args.num_noise_vec)
        noisy_inputs_list = []
        for inputs, targets in mini_batches:
            inputs = inputs.cuda()
            targets = targets.cuda()

            inputs = inputs.repeat(
                (1, args.num_noise_vec, 1, 1)).view(batch[0].shape)

            # augment inputs with noise
            noise = torch.randn_like(inputs, device='cuda') * noise_sd

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model,
                                         inputs,
                                         targets,
                                         noise=noise,
                                         num_noise_vectors=args.num_noise_vec,
                                         no_grad=args.no_grad_attack)
                model.train()
                requires_grad_(model, True)

            if args.train_multi_noise:
                noisy_inputs = inputs + noise
                targets = targets.unsqueeze(1).repeat(
                    1, args.num_noise_vec).reshape(-1, 1).squeeze()
                outputs = model(noisy_inputs)
                loss = criterion(outputs, targets)

                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), noisy_inputs.size(0))
                top1.update(acc1.item(), noisy_inputs.size(0))
                top5.update(acc5.item(), noisy_inputs.size(0))

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            else:
                inputs = inputs[::args.num_noise_vec]  # subsample the samples
                noise = noise[::args.num_noise_vec]
                # noise = torch.randn_like(inputs, device='cuda') * noise_sd
                noisy_inputs_list.append(inputs + noise)

        if not args.train_multi_noise:
            noisy_inputs = torch.cat(noisy_inputs_list)
            targets = batch[1].cuda()
            assert len(targets) == len(noisy_inputs)

            outputs = model(noisy_inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), noisy_inputs.size(0))
            top1.update(acc1.item(), noisy_inputs.size(0))
            top5.update(acc5.item(), noisy_inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
Example #12
def proto_net_episode(model: Module,
                      optimiser: Optimizer,
                      loss_fn: Callable,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      n_shot: int,
                      k_way: int,
                      q_queries: int,
                      distance: str,
                      train: bool):
    """Performs a single training episode for a Prototypical Network.

    # Arguments
        model: Prototypical Network to be trained.
        optimiser: Optimiser to calculate gradient step
        loss_fn: Loss function to calculate between predictions and outputs. Should be cross-entropy
        x: Input samples of few shot classification task
        y: Input labels of few shot classification task
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between class prototypes and queries
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Prototypical Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes

    support = embeddings[:n_shot * k_way]  # [n_shot * k_way, 64]
    queries = embeddings[n_shot * k_way:]  # [q_queries * k_way, 64]

    prototypes = compute_prototypes(support, k_way, n_shot)

    # Calculate squared distances between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = pairwise_distances(queries, prototypes, distance)

    # Calculate log p_{phi} (y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y)

    # Prediction probabilities are softmax over distances
    y_pred = (-distances).softmax(dim=1)

    if train:
        # Take gradient step
        loss.backward()
        optimiser.step()

    return loss, y_pred
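
`compute_prototypes` is not shown in this snippet; in a Prototypical Network it is just the per-class mean of the support embeddings. A sketch under the same ordering assumption as the comment above (k consecutive blocks of n support samples):

import torch

def compute_prototypes_sketch(support, k, n):
    # support: (k * n, embedding_dim), ordered as k consecutive blocks of n samples per class.
    return support.reshape(k, n, -1).mean(dim=1)        # (k, embedding_dim)

support = torch.randn(5 * 3, 64)                         # k_way=5, n_shot=3, 64-dim embeddings
print(compute_prototypes_sketch(support, 5, 3).shape)    # torch.Size([5, 64])
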
Example #13
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: Callable,
    optimizer: optim.Optimizer,
    device: torch.device,
    train_eval_freq: int = 50,
    clip_grad_norm: float = 1.0,
    verbose: bool = True,
) -> DefaultDict[str, List[float]]:
    """
    Training loop on one epoch.

    :param nn.Module model: PyTorch Neural Network
    :param DataLoader dataloader: PyTorch DataLoader
    :param Callable criterion: PyTorch Critertion
    :param optim.Optimizer optimizer: PyTorch Optimizer
    :param torch.device device: PyTorch Device
    :param int train_eval_freq: evaluation frequency (number of batches) (default: 50)
    :param float clip_grad_norm: max_norm parameter in clip_grad_norm (default: 1.0)
    :param bool verbose: verbose (default: True)
    :return: metrics dict
    :rtype: DefaultDict[str, List[float]]
    """

    metrics: DefaultDict[str, List[float]] = defaultdict(list)

    char2idx = dataloader.dataset.char2idx

    # BOS and EOS
    bos_id = char2idx[BOS]
    eos_id = char2idx[EOS]

    if verbose:
        dataloader = tqdm(dataloader, desc="iter dataloader")

    model.train()

    for i, sentence in enumerate(dataloader):
        sentence = sentence.to(device)

        # lengths and mask
        targets = sentence[:, 1:]  # clip left
        lengths = infer_lengths(sentence, bos_id=bos_id, eos_id=eos_id)
        mask = masking(lengths + 1)  # incl. EOS

        # forward pass
        outputs = model(
            sentence[:, :-1],  # clip right
            lengths + 1,  # incl. BOS
        )
        loss_matrix = criterion(
            input=outputs.transpose(1, 2),
            target=targets,
        )
        loss = (loss_matrix * mask).sum() / mask.sum()

        # backward pass
        loss.backward()

        # clip grad norm
        grad_norm = nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=clip_grad_norm,
        )

        # optimizer step
        optimizer.step()
        optimizer.zero_grad()

        # calculate metrics
        metrics["loss"].append(loss.item())
        metrics["grad_norm"].append(grad_norm.item())

        if verbose:
            if i % train_eval_freq == 0:
                generated_sequence = generate(
                    model=model,
                    char2idx=char2idx,
                    prefix="",
                    temperature=0.5,  # hardcoded
                    max_length=100,  # hardcoded
                )
                model.train()  # generate() leaves the model in eval mode; switch back to train

                for metric_name, metric_list in metrics.items():
                    print(
                        f"{metric_name}: {np.mean(metric_list[-train_eval_freq:])}"
                    )
                print(f"inference: {generated_sequence}\n")

    return metrics
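
The `(loss_matrix * mask).sum() / mask.sum()` reduction averages the token-level loss over real (non-padding) positions only. A toy illustration of that reduction, independent of the model and the `masking` helper used above:

import torch

loss_matrix = torch.tensor([[0.5, 0.7, 0.2],
                            [0.4, 0.9, 0.1]])     # per-token losses, shape (batch, seq_len)
mask = torch.tensor([[1., 1., 0.],
                     [1., 0., 0.]])               # 1 = real token, 0 = padding

loss = (loss_matrix * mask).sum() / mask.sum()    # mean over the 3 real tokens only
print(loss.item())                                # (0.5 + 0.7 + 0.4) / 3
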
def matching_net_episode(model: Module, optimiser: Optimizer, loss_fn: Loss,
                         x: torch.Tensor, y: torch.Tensor, n_shot: int,
                         k_way: int, q_queries: int, distance: str, fce: bool,
                         train: bool):
    """Performs a single training episode for a Matching Network.

    # Arguments
        model: Matching Network to be trained.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples of few shot classification task
        y: Input labels of few shot classification task
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between support and query set samples
        fce: Whether or not to use fully conditional embeddings
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Matching Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model.encoder(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]

    # Optionally apply full context embeddings
    if fce:
        # LSTM requires input of shape (seq_len, batch, input_size). `support` is of
        # shape (k_way * n_shot, embedding_dim) and we want the LSTM to treat the
        # support set as a sequence so add a single dimension to transform support set
        # to the shape (k_way * n_shot, 1, embedding_dim) and then remove the batch dimension
        # afterwards

        # Calculate the fully conditional embedding, g, for support set samples as described
        # in appendix A.2 of the paper. g takes the form of a bidirectional LSTM with a
        # skip connection from inputs to outputs
        support, _, _ = model.g(support.unsqueeze(1))
        support = support.squeeze(1)

        # Calculate the fully conditional embedding, f, for the query set samples as described
        # in appendix A.1 of the paper.
        queries = model.f(support, queries)

    # Efficiently calculate distance between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = pairwise_distances(queries, support, distance)

    # Calculate "attention" as softmax over support-query distances
    attention = (-distances).softmax(dim=1)

    # Calculate predictions as in equation (1) from Matching Networks
    # y_hat = \sum_{i=1}^{k} a(x_hat, x_i) y_i
    y_pred = matching_net_predictions(attention, n_shot, k_way, q_queries)

    # Calculated loss with negative log likelihood
    # Clip predictions for numerical stability
    clipped_y_pred = y_pred.clamp(EPSILON, 1 - EPSILON)
    loss = loss_fn(clipped_y_pred.log(), y)

    if train:
        # Backpropagate gradients
        loss.backward()
        # I found training to be quite unstable so I clip the norm
        # of the gradient to be at most 1
        clip_grad_norm_(model.parameters(), 1)
        # Take gradient step
        optimiser.step()

    return loss, y_pred
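
`matching_net_predictions` is not shown; in a Matching Network the query predictions are the attention-weighted sum of the one-hot support labels, y_hat = sum_i a(x_hat, x_i) y_i, as the comment above notes. A sketch under the same sample-ordering assumption:

import torch
import torch.nn.functional as F

def matching_net_predictions_sketch(attention, n, k):
    # attention: (q_queries * k, n * k) softmax weights over the support set.
    # Support labels come in k blocks of n samples, so class c occupies rows c*n ... c*n + n - 1.
    support_labels = torch.arange(k).repeat_interleave(n)       # (n * k,)
    one_hot = F.one_hot(support_labels, num_classes=k).float()  # (n * k, k)
    return attention @ one_hot                                  # (q_queries * k, k) class probabilities

attention = torch.rand(10, 15).softmax(dim=1)                   # q_queries*k = 10, n_shot*k = 15
print(matching_net_predictions_sketch(attention, n=3, k=5).shape)
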
Example #15
def update_theta(epoch: int, baseline: MovingAverageMetric, entropy_coeff,
                 grad_clip: int, data_loader, device: str,
                 master_pair: MasterPairs, architecture: NASNetwork,
                 optimizer: optim.Optimizer, writer: SummaryWriter,
                 log_frequency: int):
    start = datetime.now()
    policy_loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    normal_logp_metric = AverageMetric()
    node_normal_entropy_metric = AverageMetric()
    op_normal_entropy_metric = AverageMetric()
    reduced_logp_metric = AverageMetric()
    node_reduced_entropy_metric = AverageMetric()
    op_reduced_entropy_metric = AverageMetric()

    node_normal_entropy_coeff, node_reduced_entropy_coeff, \
        op_normal_entropy_coeff, op_reduced_entropy_coeff = [entropy_coeff, ]*4

    master_pair.unset_force_uniform()

    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)
        (normal_arch, normal_logp, node_normal_entropy, op_normal_entropy), \
            (reduced_arch, reduced_logp, node_reduced_entropy,
             op_reduced_entropy) = master_pair()
        with torch.no_grad():
            outputs = architecture(datas, normal_arch, reduced_arch)
        accuracy_metric.update(targets, outputs)
        accuracy_1 = accuracy_metric.last_accuracy(1).rate
        baseline.update(accuracy_1)
        reward = accuracy_1 - baseline.value
        policy_loss = -(normal_logp + reduced_logp) * reward \
            - (node_normal_entropy*node_normal_entropy_coeff
               + op_normal_entropy*op_normal_entropy_coeff
                + node_reduced_entropy*node_reduced_entropy_coeff
                + op_reduced_entropy*op_reduced_entropy_coeff)

        optimizer.zero_grad()
        policy_loss.backward()
        if grad_clip is not None:
            nn.utils.clip_grad_norm_(master_pair.parameters(), grad_clip)
        optimizer.step()

        # update metrics
        policy_loss_metric.update(policy_loss)
        normal_logp_metric.update(normal_logp)
        node_normal_entropy_metric.update(node_normal_entropy)
        op_normal_entropy_metric.update(op_normal_entropy)
        reduced_logp_metric.update(reduced_logp)
        node_reduced_entropy_metric.update(node_reduced_entropy)
        op_reduced_entropy_metric.update(op_reduced_entropy)

        # iteration log
        if iter_ % log_frequency == 0 or iter_ == len(data_loader):
            message = f"UPDATE THETA, epoch={epoch:03d}, Iter={iter_}/{len(data_loader)}, "
            message += f"reward={reward:.4f}, "
            message += f"policy loss={policy_loss_metric.last:.4f}({policy_loss_metric.value:.4f}), "
            message += f"moving accuracy={baseline.value*100:.2f}%, "
            message += f"normal_logp={normal_logp_metric.last:.4f}({normal_logp_metric.value:.4f}), "
            message += f"node_normal_entropy={node_normal_entropy_metric.last:.4f}({node_normal_entropy_metric.value:.4f}), "
            message += f"op_normal_entropy={op_normal_entropy_metric.last:.4f}({op_normal_entropy_metric.value:.4f}), "
            message += f"reduced_logp={reduced_logp_metric.last:.4f}({reduced_logp_metric.value:.4f}), "
            message += f"node_reduced_entropy={node_reduced_entropy_metric.last:.4f}({node_reduced_entropy_metric.value:.4f}), "
            message += f"op_reduced_entropy={op_reduced_entropy_metric.last:.4f}({op_reduced_entropy_metric.value:.4f})."
            if iter_ == len(data_loader):
                message += f" Elapsed time={datetime.now()-start}."
            utils.logger.info(message)

    writer.add_scalar("update_theta/policy_loss", policy_loss_metric.value,
                      epoch)
    writer.add_scalar("update_theta/baseline", baseline.value, epoch)
    writer.add_scalar("update_theta/accuracy@1",
                      accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar("update_theta/accuracy@5",
                      accuracy_metric.accuracy(5).rate, epoch)
    writer.add_scalar("update_theta/normal_logp", normal_logp_metric.value,
                      epoch)
    writer.add_scalar("update_theta/node_normal_entropy",
                      node_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/op_normal_entropy",
                      op_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/reduced_logp", reduced_logp_metric.value,
                      epoch)
    writer.add_scalar("update_theta/node_reduced_entropy",
                      node_reduced_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/op_reduced_entropy",
                      op_reduced_entropy_metric.value, epoch)
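The baseline passed into update_theta is a MovingAverageMetric whose definition is not included here; it only needs update() and value. A minimal sketch, assuming an exponential moving average of the top-1 accuracy (the momentum value is illustrative):

class MovingAverageMetric:
    """Exponential moving average used as a REINFORCE baseline (sketch only)."""

    def __init__(self, momentum: float = 0.99):
        self.momentum = momentum
        self._value = None

    def update(self, x: float) -> None:
        x = float(x)
        # The first observation initialises the average; later ones decay towards it
        self._value = x if self._value is None else \
            self.momentum * self._value + (1.0 - self.momentum) * x

    @property
    def value(self) -> float:
        return 0.0 if self._value is None else self._value

Subtracting such a baseline from the accuracy leaves the policy gradient unbiased while reducing its variance.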
Example #16
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum, iter_count = 0, 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch = \
            batch.batch_graph(), batch.features(), batch.targets()
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        class_weights = torch.ones(targets.shape, device=preds.device)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
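The mask built from target is not None is what lets the loop train on sparse multitask labels: missing targets contribute zero loss and are excluded from the average. A self-contained toy illustration of the same pattern (the tensors and loss function here are made up for demonstration):

import torch
import torch.nn as nn

# 3 molecules x 2 tasks, with one missing label
target_batch = [[1.0, None], [0.0, 1.0], [1.0, 0.0]]
mask = torch.tensor([[x is not None for x in tb] for tb in target_batch], dtype=torch.float)
targets = torch.tensor([[0.0 if x is None else x for x in tb] for tb in target_batch])

preds = torch.zeros(3, 2)                           # stand-in model outputs (logits)
loss_func = nn.BCEWithLogitsLoss(reduction='none')  # element-wise, as in the loop above

loss = loss_func(preds, targets) * mask             # missing entries contribute zero
loss = loss.sum() / mask.sum()                      # average over observed targets only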
Example #17
def train(
    model: nn.Module,
    num_epochs: int,
    dataloader: DataLoader,
    optimizer: Optimizer,
    lr_scheduler: Optional[LRScheduler] = None,
    num_gradient_accumulation_steps: Optional[int] = 1,
    max_gradient_norm: Optional[float] = None,
    device: Optional[torch.device] = torch.device('cpu'),
    local_rank: Optional[int] = 0,
    use_distributed: Optional[bool] = False,
    is_master: Optional[bool] = True,
    use_tqdm: Optional[bool] = True,
    logger: Optional[Logger] = None,
) -> None:
    # put model in train mode
    model.train()

    # keep track of the last loss
    last_loss = 0

    for epoch in range(num_epochs):
        # synchronize all processes
        if use_distributed:
            dist.barrier()

        if is_master and logger is not None:
            logger.info(f'Starting with epoch {epoch+1}/{num_epochs}')

        # initialize the progress bar
        if is_master and use_tqdm:
            pbar = tqdm(
                desc=f'Training [epoch {epoch+1}/{num_epochs}]',
                total=len(dataloader),
                unit='batch',
            )

        for step, batch in enumerate(dataloader):
            # unpack batch
            sequences, attention_masks, _, start_positions, end_positions, _, _, _ = batch

            # send sequences, attention_masks, start_positions and end_positions to device
            sequences = sequences.to(device)
            attention_masks = attention_masks.to(device)
            start_positions = start_positions.to(device)
            end_positions = end_positions.to(device)

            # forward pass (loss computation included)
            outputs = model(input_ids=sequences,
                            attention_mask=attention_masks,
                            start_positions=start_positions,
                            end_positions=end_positions)
            loss = outputs[0]
            last_loss = loss.item()

            if use_distributed:
                loss = loss.mean()

            # rescale the loss
            loss /= num_gradient_accumulation_steps

            # backward pass
            loss.backward()

            if step % num_gradient_accumulation_steps == 0:
                # clip the gradient
                if max_gradient_norm is not None:
                    clip_grad_norm_(model.parameters(), max_gradient_norm)

                # update the parameters
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()

                # clear all gradients
                optimizer.zero_grad()

            # update the progress bar
            if is_master and use_tqdm:
                pbar.update()
                pbar.set_postfix({'last_loss': last_loss})

        # close the progress bar
        if is_master and use_tqdm:
            pbar.close()
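The loss rescaling and the step % num_gradient_accumulation_steps check implement gradient accumulation: with per-batch size B and accumulation over A steps, each optimizer update averages gradients over an effective batch of B * A samples. Note that checking step % A == 0 also fires on the very first batch; a common variant tests (step + 1) % A == 0 so that exactly A micro-batches are accumulated per update. A self-contained sketch of that variant (toy model and data, illustrative values):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=4)

accumulation_steps = 4  # effective batch size = 4 * 4 = 16

optimizer.zero_grad()
for step, (inputs, labels) in enumerate(loader):
    loss = criterion(model(inputs), labels) / accumulation_steps  # rescale so the accumulated sum is an average
    loss.backward()                                               # gradients accumulate in .grad
    if (step + 1) % accumulation_steps == 0:                      # one update per 4 micro-batches
        optimizer.step()
        optimizer.zero_grad()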
Example #18
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print
    
    model.train()
    
    data.shuffle()

    loss_sum, iter_count = 0, 0

    iter_size = args.batch_size

    if args.class_balance:
        # Reconstruct data so that each batch has equal number of positives and negatives
        # (will leave out a different random sample of negatives each epoch)
        assert len(data[0].targets) == 1  # only works for single class classification
        pos = [d for d in data if d.targets[0] == 1]
        neg = [d for d in data if d.targets[0] == 0]

        new_data = []
        pos_size = iter_size // 2
        pos_index = neg_index = 0
        while True:
            new_pos = pos[pos_index:pos_index + pos_size]
            new_neg = neg[neg_index:neg_index + iter_size - len(new_pos)]

            if len(new_pos) == 0 or len(new_neg) == 0:
                break

            if len(new_pos) + len(new_neg) < iter_size:
                new_pos = pos[pos_index:pos_index + iter_size - len(new_neg)]

            new_data += new_pos + new_neg

            pos_index += len(new_pos)
            neg_index += len(new_neg)

        data = new_data

    num_iters = len(data) // args.batch_size * args.batch_size  # don't use the last batch if it's small, for stability

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1) for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
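The class_balance branch interleaves positives and negatives so that every batch of size iter_size holds roughly iter_size // 2 positives until one of the pools runs out. The same interleaving on plain lists, stripped of the chemprop data types (toy data; the small top-up branch is omitted):

batch_size = 4
pos = [f'pos{i}' for i in range(3)]   # scarce positives
neg = [f'neg{i}' for i in range(10)]  # abundant negatives

new_data, pos_index, neg_index = [], 0, 0
while True:
    new_pos = pos[pos_index:pos_index + batch_size // 2]
    new_neg = neg[neg_index:neg_index + batch_size - len(new_pos)]
    if len(new_pos) == 0 or len(new_neg) == 0:
        break
    new_data += new_pos + new_neg
    pos_index += len(new_pos)
    neg_index += len(new_neg)

# new_data == ['pos0', 'pos1', 'neg0', 'neg1', 'pos2', 'neg2', 'neg3', 'neg4']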
Example #19
def train_fn(
    model: nn.Module,
    loader: DataLoader,
    device: str,
    loss_fn: nn.Module,
    optimizer: optim.Optimizer,
    scheduler=None,
    accumulation_steps: int = 1,
    verbose: bool = True,
) -> dict:
    """Train step.

    Args:
        model (nn.Module): model to train
        loader (DataLoader): loader with data
        device (str): device to use for placing batches
        loss_fn (nn.Module): loss function, should be callable
        optimizer (optim.Optimizer): model parameters optimizer
        scheduler ([type], optional): batch scheduler to use.
            Default is `None`.
        accumulation_steps (int, optional): number of steps to accumulate gradients.
            Default is `1`.
        verbose (bool, optional): verbosity mode.
            Default is True.

    Returns:
        dict with metics computed during the training on loader
    """
    model.train()

    metrics = {
        "loss": [],
        "gap": [],
        "accuracy": [],
    }

    with tqdm(total=len(loader), desc="train",
              disable=not verbose) as progress:
        for _idx, batch in enumerate(loader):
            inputs, targets = t2d(batch, device)

            zero_grad(optimizer)

            outputs = model(inputs, targets)
            loss = loss_fn(outputs, targets)

            _loss = loss.detach().item()
            metrics["loss"].append(_loss)

            classes = torch.argmax(outputs, 1)
            _acc = (classes == targets).float().mean().detach().item()
            metrics["accuracy"].append(_acc)

            confidences, predictions = torch.max(outputs, dim=1)
            _gap = gap(predictions, confidences, targets)
            metrics["gap"].append(_gap)

            loss.backward()

            progress.set_postfix_str(
                f"loss {_loss:.4f}, gap {_gap:.4f}, accuracy {_acc:.4f}")

            if (_idx + 1) % accumulation_steps == 0:
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            progress.update(1)

            if _idx == DEBUG:
                break

    metrics["loss"] = np.mean(metrics["loss"])
    metrics["gap"] = np.mean(metrics["gap"])
    metrics["accuracy"] = np.mean(metrics["accuracy"])
    return metrics
Example #20
def meta_gradient_step(model: Module,
                       optimiser: Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples for all few shot tasks
        y: Input labels of all few shot tasks
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few shot classification task of each task
        q_queries: Number of examples per class in the query set of each task. The query set is used to calculate
            meta-gradients after applying the inner-loop update computed on the support set
        order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the
            query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated
            weights on the query with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation
    """
    data_shape = x.shape[2:]
    create_graph = (True if order == 2 else False) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch in x:
        # By construction x is a 5D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height)
        # Hence when we iterate over the first  dimension we are iterating through the meta batches
        x_task_train = meta_batch[:n_shot * k_way]
        x_task_val = meta_batch[n_shot * k_way:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            y = create_nshot_task_label(k_way, n_shot).to(device)
            logits = model.functional_forward(x_task_train, fast_weights)
            loss = loss_fn(logits, y)
            gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), gradients)
            )

        # Do a pass of the model on the validation data from the current task
        y = create_nshot_task_label(k_way, q_queries).to(device)
        logits = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y)
        loss.backward(retain_graph=True)

        # Get post-update accuracies
        y_pred = logits.softmax(dim=1)
        task_predictions.append(y_pred)

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
        named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)}
        task_gradients.append(named_grads)

    if order == 1:
        if train:
            sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                                  for k in task_gradients[0].keys()}
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients, name))
                )

            model.train()
            optimiser.zero_grad()
            # Dummy pass in order to create `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
            loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
            loss.backward()
            optimiser.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)

    elif order == 2:
        model.train()
        optimiser.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            meta_batch_loss.backward()
            optimiser.step()

        return meta_batch_loss, torch.cat(task_predictions)
    else:
        raise ValueError('Order must be either 1 or 2.')
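meta_gradient_step depends on model.functional_forward, a forward pass that takes the parameters explicitly so the fast weights stay part of the autograd graph. A minimal sketch of such a model and one inner-loop update, assuming a small fully connected classifier (not the architecture used in the repository):

import torch
import torch.nn.functional as F
from collections import OrderedDict
from torch import nn

class TinyClassifier(nn.Module):
    def __init__(self, in_features: int, k_way: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 32)
        self.fc2 = nn.Linear(32, k_way)

    def forward(self, x):
        return self.functional_forward(x, OrderedDict(self.named_parameters()))

    def functional_forward(self, x, weights: OrderedDict):
        # Same computation as forward(), but with explicitly supplied parameters
        x = F.relu(F.linear(x, weights['fc1.weight'], weights['fc1.bias']))
        return F.linear(x, weights['fc2.weight'], weights['fc2.bias'])

# One inner-loop step; create_graph=True keeps the update differentiable for 2nd-order MAML
model = TinyClassifier(in_features=16, k_way=5)
x_support, y_support = torch.randn(5, 16), torch.arange(5)
fast_weights = OrderedDict(model.named_parameters())
loss = F.cross_entropy(model.functional_forward(x_support, fast_weights), y_support)
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = OrderedDict((name, p - 0.01 * g)
                           for (name, p), g in zip(fast_weights.items(), grads))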
Example #21
File: train.py  Project: ks8/glassML
def train(model: nn.Module,
          data: DataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: NoamLR,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.
    :param model: Model.
    :param data: A DataLoader.
    :param loss_func: Loss function.
    :param optimizer: Optimizer.
    :param scheduler: A NoamLR learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    loss_sum, iter_count = 0, 0
    for batch in tqdm(data, total=len(data)):
        if args.cuda:
            targets = batch.y.float().unsqueeze(1).cuda()
        else:
            targets = batch.y.float().unsqueeze(1)
        batch = GlassBatchMolGraph(batch)  # TODO: Apply a check for connectivity of graph

        # Run model
        model.zero_grad()
        preds = model(batch)
        loss = loss_func(preds, targets)
        loss = loss.sum() / loss.size(0)

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.max_grad_norm is not None:
            clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lr = scheduler.get_lr()[0]
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, lr = {:.4e}".
                  format(loss_avg, pnorm, gnorm, lr))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                writer.add_scalar('learning_rate', lr, n_iter)

    return n_iter
Example #22
    def adv_train(self, epoch: int, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False, verbose=True, indent=0,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = os.path.join(self.folder_path, self.get_filename() + '.pth')

        _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        for _epoch in range(epoch):
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            self.model.activate_params(params)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)

                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=loss_fn, iteration=1)
                    optimizer.zero_grad()
                    self.model.train()
                    loss = self.model.loss(adv_x, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc >= best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                    if save:
                        self.model.save(file_path=file_path, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
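adv_train alternates clean and adversarial updates but delegates the inner maximization to self.pgd.optimize, whose implementation is not shown. A generic single L-infinity PGD step on the classification loss, for orientation only (epsilon and alpha are illustrative, and this is not the project's PGD class):

import torch
import torch.nn.functional as F

def pgd_step(model, x, y, noise, epsilon=8 / 255, alpha=2 / 255):
    """One L-infinity PGD ascent step on the cross-entropy loss (sketch)."""
    adv_x = (x + noise).detach().requires_grad_(True)
    loss = F.cross_entropy(model(adv_x), y)
    grad, = torch.autograd.grad(loss, adv_x)
    noise = (noise + alpha * grad.sign()).clamp(-epsilon, epsilon)  # ascend, then project onto the eps-ball
    adv_x = (x + noise).clamp(0, 1)                                 # keep pixel values valid
    return adv_x.detach(), noise.detach()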
Example #23
File: utils.py  Project: chomd90/snip
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          device,
          print_freq,
          display=False):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.to(device)
        targets = targets.to(device)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 and display:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
Example #24
def train_controller(max_iter: int,
                     database: DataBase,
                     entropy_coeff: float,
                     grad_clip: int,
                     controller: NASBenchController,
                     nac: NAC,
                     optimizer: optim.Optimizer,
                     writer: tensorboard.SummaryWriter,
                     alternate_train,
                     alternate_evaluate,
                     random_baseline=False,
                     log_frequence: int = 10,
                     search_space=None):
    controller.train()
    nac.eval()
    optimizer.zero_grad()

    policy_loss_avg = MovingAverageMetric()
    entropy_mavg = MovingAverageMetric()
    logp_mavg = MovingAverageMetric()
    score_avg = MovingAverageMetric()

    pseudo_architecture_set = None

    with torch.no_grad():
        *arch_seq, _, _ = controller(force_uniform=True)
        raw_arch = seq2arch_fn(arch_seq)
        baseline_arch = [tensorize_fn(raw_arch, device=device)]

    best_collect_archs = [arch_seq]

    for iter_ in range(max_iter):

        if iter_ % args.n_iteration_update_pseudoset == 0 and args.pseudo_ratio != 0:
            if pseudo_architecture_set is None:
                pseudo_architecture_set = \
                    generate_architecture_with_pseudo_labels(
                        nac, controller,
                        2*int(args.pseudo_ratio*args.train_batch_size),
                        int(args.pseudo_ratio*args.train_batch_size))
            else:
                pseudo_architecture_set = list_concat(
                    pseudo_architecture_set,
                    generate_architecture_with_pseudo_labels(
                        nac, controller, 2 * args.n_sample_architectures,
                        args.n_sample_architectures))

            epoch = args.nac_epochs + iter_
            accuracy, rank_loss = alternate_train(
                epoch=epoch, pseudo_set=pseudo_architecture_set)
            writer.add_scalar("nac/train_accuracy", accuracy, epoch)
            writer.add_scalar("nac/loss", rank_loss, epoch)
            KTau = alternate_evaluate(epoch=epoch)
            writer.add_scalar("nac/ktau", KTau, epoch)

        *arch_seq, logp, entropy = controller()
        with torch.no_grad():
            sample_arch = [tensorize_fn(seq2arch_fn(arch_seq), device=device)]
            score = nac(batchify(sample_arch), batchify(baseline_arch))
            score = score.mean().item()

        policy_loss = -logp * score - entropy_coeff * entropy

        optimizer.zero_grad()
        policy_loss.backward()
        if grad_clip is not None:
            # clip after backward so the freshly computed gradients are what gets clipped
            nn.utils.clip_grad_norm_(controller.parameters(), grad_clip)
        optimizer.step()

        policy_loss_avg.update(policy_loss)
        entropy_mavg.update(entropy)
        logp_mavg.update(logp)
        score_avg.update(score)

        if iter_ % log_frequence == 0:
            logger.info(", ".join([
                "Policy Learning",
                f"iter={iter_:03d}",
                f"policy loss={policy_loss_avg.compute():.4f}",
                f"entropy={entropy_mavg.compute():.4f}",
                f"logp={logp_mavg.compute():.4f}",
            ]))
            writer.add_scalar("policy_learning/loss",
                              policy_loss_avg.compute(), iter_)
            writer.add_scalar("policy_learning/entropy",
                              entropy_mavg.compute(), iter_)
            writer.add_scalar("policy_learning/logp", logp_mavg.compute(),
                              iter_)
            writer.add_scalar("policy_learning/reward", score_avg.compute(),
                              iter_)

        if iter_ % args.evaluate_controller_freq == 0:
            baseline_arch, best_collect_archs = derive(iter_, controller, nac,
                                                       10, database, writer,
                                                       best_collect_archs,
                                                       random_baseline,
                                                       search_space)
            torch.save(controller.state_dict(),
                       os.path.join(args.output, f"controller-{iter_}.pth"))
Example #25
def train(model: nn.Module,
          data_loader: DataLoader,
          optimizer: Optimizer,
          loss_config: DictConfig,
          epoch: int,
          device: str = 'cuda',
          tblogger: Optional[SummaryWriter] = None):
    """ref. https://github.com/JunjH/Revisiting_Single_Depth_Estimation/blob/master/train.py"""

    model.train()

    # func for loss
    cos = nn.CosineSimilarity(dim=1, eps=0)
    get_gradient = Sobel().to(device)

    # init
    batch_time = AverageMeter()
    losses = AverageMeter()
    losses_depth = AverageMeter()
    losses_normal = AverageMeter()
    losses_grad = AverageMeter()
    end = time.time()
    for i, batch in enumerate(data_loader):

        # prepare
        image, depth = batch['image'], batch['depth']
        image = image.to(device)
        depth = depth.to(device)
        optimizer.zero_grad()

        # forward
        output = model(image)

        # loss: depth
        loss_depth = torch.log(torch.abs(output - depth) +
                               loss_config.ALPHA).mean()

        # loss: grad
        depth_grad = get_gradient(depth)
        output_grad = get_gradient(output)
        depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as(depth)
        depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as(depth)
        output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as(depth)
        output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as(depth)

        loss_dx = torch.log(
            torch.abs(output_grad_dx - depth_grad_dx) +
            loss_config.ALPHA).mean()
        loss_dy = torch.log(
            torch.abs(output_grad_dy - depth_grad_dy) +
            loss_config.ALPHA).mean()

        # loss: normal
        ones = torch.ones(depth.size(0),
                          1,
                          depth.size(2),
                          depth.size(3),
                          requires_grad=True).to(device)
        depth_normal = torch.cat((-depth_grad_dx, -depth_grad_dy, ones), 1)
        output_normal = torch.cat((-output_grad_dx, -output_grad_dy, ones), 1)

        loss_normal = torch.abs(1 - cos(output_normal, depth_normal)).mean()

        # loss
        loss = loss_depth \
            + loss_config.LAMBDA * (loss_dx + loss_dy) \
            + loss_config.MU * loss_normal

        # update
        bs = image.size(0)
        losses.update(loss.item(), bs)
        losses_depth.update(loss_depth.item(), bs)
        losses_normal.update(loss_normal.item(), bs)
        losses_grad.update((loss_dx + loss_dy).item(), bs)

        # step
        loss.backward()
        optimizer.step()

        # time
        batch_time.update(time.time() - end)
        end = time.time()

        # log
        print(f'epoch {epoch}[{i}/{len(data_loader)}], '
              f'time {batch_time.value:.3f} ({batch_time.sum:.3f}), '
              f'loss {losses.value:.4f} ({losses.avg:.4f}), '
              f'l_d {losses_depth.value:.4f} ({losses_depth.avg:.4f}), '
              f'l_g {losses_grad.value:.4f} ({losses_grad.avg:.4f}), '
              f'l_n {losses_normal.value:.4f} ({losses_normal.avg:.4f}), ')

    if tblogger is not None:
        tblogger.add_scalar('train/loss', losses.avg, epoch + 1)
        tblogger.add_scalar('train/l_d', losses_depth.avg, epoch + 1)
        tblogger.add_scalar('train/l_g', losses_grad.avg, epoch + 1)
        tblogger.add_scalar('train/l_n', losses_normal.avg, epoch + 1)
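The gradient and normal losses rely on a Sobel module that returns per-pixel depth derivatives. A minimal version built from fixed 3x3 Sobel kernels, which is one common way to implement it (the project's own module may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Sobel(nn.Module):
    """Fixed Sobel filters returning dx and dy stacked along the channel dimension (sketch)."""

    def __init__(self):
        super().__init__()
        kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
        ky = kx.t()
        self.register_buffer('weight', torch.stack([kx, ky]).unsqueeze(1))  # (2, 1, 3, 3)

    def forward(self, depth: torch.Tensor) -> torch.Tensor:
        # depth: (N, 1, H, W) -> gradients: (N, 2, H, W)
        return F.conv2d(depth, self.weight, padding=1)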
Example #26
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    m = Bernoulli(torch.tensor([args.calibrated_alpha]).cuda())

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # make MNIST binary
        if args.dataset == 'mnist':
            inputs = (inputs > 0.5).type(torch.cuda.FloatTensor)

        # augment inputs with noise
        if args.perturb == 'bernoulli':
            mask = m.sample(inputs.shape).squeeze(-1)
            # make sure that the value is normalized
            rand_inputs = torch.randint_like(
                inputs, low=0, high=args.K + 1, device='cuda') / float(args.K)
            inputs = inputs * mask + rand_inputs * (1 - mask)
        elif args.perturb == 'gaussian':
            inputs = inputs + torch.randn_like(inputs,
                                               device='cuda') * args.sigma

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i + 1,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
Example #27
    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int,
                           lambda_closure: Callable, **kwargs):
        optimizer.step(closure=lambda_closure, **kwargs)
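Passing the closure through to optimizer.step matters for optimizers such as LBFGS, which may re-evaluate the loss several times per step. A minimal, framework-free usage sketch of the closure pattern (toy model and data):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(8, 3), torch.randn(8, 1)

def closure():
    # LBFGS calls this to (re-)compute the loss and gradients
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss

optimizer.step(closure=closure)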
Example #28
def update_w(epoch: int, data_loader, device: str, master_pair: MasterPairs,
             architecture: NASNetwork, criterion: nn.Module,
             optimizer: optim.Optimizer, force_uniform: bool,
             writer: SummaryWriter, log_frequency: int):
    start = datetime.now()
    loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    normal_logp_metric = AverageMetric()
    node_normal_entropy_metric = AverageMetric()
    op_normal_entropy_metric = AverageMetric()
    reduced_logp_metric = AverageMetric()
    node_reduced_entropy_metric = AverageMetric()
    op_reduced_entropy_metric = AverageMetric()

    master_pair.set_force_uniform(force_uniform=force_uniform)

    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)
        with torch.no_grad():
            (normal_arch, normal_logp, node_normal_entropy, op_normal_entropy), \
                (reduced_arch, reduced_logp, node_reduced_entropy,
                 op_reduced_entropy) = master_pair()

        outputs = architecture(datas, normal_arch, reduced_arch)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update metrics
        loss_metric.update(loss)
        accuracy_metric.update(targets, outputs)
        normal_logp_metric.update(normal_logp)
        node_normal_entropy_metric.update(node_normal_entropy)
        op_normal_entropy_metric.update(op_normal_entropy)
        reduced_logp_metric.update(reduced_logp)
        node_reduced_entropy_metric.update(node_reduced_entropy)
        op_reduced_entropy_metric.update(op_reduced_entropy)

        # iteration log
        if iter_ % log_frequency == 0 or iter_ == len(data_loader):
            message = f"UPDATE W, epoch={epoch:03d}, iter={iter_}/{len(data_loader)}, "
            message += f"celoss={loss_metric.last:.4f}({loss_metric.value:.4f}), "
            message += f"accuracy@1={accuracy_metric.last_accuracy(1).rate*100:.2f}%"
            message += f"({accuracy_metric.accuracy(1).rate*100:.2f}%), "
            message += f"accuracy@5={accuracy_metric.last_accuracy(5).rate*100:.2f}%"
            message += f"({accuracy_metric.accuracy(5).rate*100:.2f}%), "
            message += f"normal_logp={normal_logp_metric.last:.4f}({normal_logp_metric.value:.4f}), "
            message += f"node_normal_entropy={node_normal_entropy_metric.last:.4f}({node_normal_entropy_metric.value:.4f}), "
            message += f"op_normal_entropy={op_normal_entropy_metric.last:.4f}({op_normal_entropy_metric.value:.4f}), "
            message += f"reduced_logp={reduced_logp_metric.last:.4f}({reduced_logp_metric.value:.4f}), "
            message += f"node_reduced_entropy={node_reduced_entropy_metric.last:.4f}({node_reduced_entropy_metric.value:.4f}), "
            message += f"op_reduced_entropy={op_reduced_entropy_metric.last:.4f}({op_reduced_entropy_metric.value:.4f})."
            if iter_ == len(data_loader):
                message += f" Elapsed time={datetime.now()-start}."
            utils.logger.info(message)

    writer.add_scalar("update_w/celoss", loss_metric.value, epoch)
    writer.add_scalar("update_w/accuracy@1",
                      accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar("update_w/accuracy@5",
                      accuracy_metric.accuracy(5).rate, epoch)
    writer.add_scalar("update_w/normal_logp", normal_logp_metric.value, epoch)
    writer.add_scalar("update_w/node_normal_entropy",
                      node_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_w/op_normal_entropy",
                      op_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_w/reduced_logp", reduced_logp_metric.value,
                      epoch)
    writer.add_scalar("update_w/node_reduced_entropy",
                      node_reduced_entropy_metric.value, epoch)
    writer.add_scalar("update_w/op_reduced_entropy",
                      op_reduced_entropy_metric.value, epoch)
Example #29
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: Callable,
    optimizer: optim.Optimizer,
    device: torch.device,
    clip_grad_norm: float,
    verbose: bool = True,
) -> DefaultDict[str, List[float]]:
    """
    Training loop on one epoch.
    """

    metrics: DefaultDict[str, List[float]] = defaultdict(list)
    idx2label = {v: k for k, v in dataloader.dataset.label2idx.items()}

    if verbose:
        dataloader = tqdm(dataloader)

    model.train()

    for tokens, labels, lengths in dataloader:
        tokens, labels, lengths = (
            tokens.to(device),
            labels.to(device),
            lengths.to(device),
        )

        mask = masking(lengths)

        # forward pass
        logits = model(tokens, lengths)
        loss_without_reduction = criterion(logits.transpose(-1, -2), labels)
        loss = torch.sum(loss_without_reduction * mask) / torch.sum(mask)

        # backward pass
        loss.backward()

        # gradient clipping
        nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=clip_grad_norm,
            norm_type=2,
        )

        optimizer.step()
        optimizer.zero_grad()

        # make predictions
        y_true = to_numpy(labels[mask])
        y_pred = to_numpy(logits.argmax(dim=-1)[mask])

        # calculate metrics
        metrics = calculate_metrics(
            metrics=metrics,
            loss=loss.item(),
            y_true=y_true,
            y_pred=y_pred,
            idx2label=idx2label,
        )

    return metrics
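train_epoch uses a masking(lengths) helper that is not defined in this snippet; it needs to return a boolean padding mask of shape (batch, max_len). A minimal sketch with that assumed behaviour:

import torch

def masking(lengths: torch.Tensor) -> torch.Tensor:
    """True where a token is real, False on padding positions (sketch)."""
    max_len = int(lengths.max())
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)

# masking(torch.tensor([3, 1, 2])) ->
# tensor([[ True,  True,  True],
#         [ True, False, False],
#         [ True,  True, False]])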
Example #30
    def forward(self,
                x,
                opt: optim.Optimizer,
                step,
                summary_writer: torch.utils.tensorboard.SummaryWriter = None,
                sample_gpu=None):
        """
        train inside forward
        """
        opt.zero_grad()
        batch_size, num_pts = x.shape[:2]
        z_mu, z_sigma = self.encoder(x)
        # Compute Q(z|X) and entropy H{Q(z|X)}
        if self.use_deterministic_encoder:
            z = z_mu + 0 * z_sigma  # ? why, the original code added this 0 multiplier
            entropy = torch.zeros(batch_size).to(z)
        else:
            z = self.reparametrized_gaussian(z_mu, z_sigma)
            entropy = self.gaussian_entropy(z_sigma)

        # Compute prior P(z)
        if self.use_latent_flow:
            w, dlog_pw = self.latentCNF(z, None,
                                        torch.zeros(batch_size, 1).to(z))
            log_pw = standard_normal_logp(w).view(batch_size,
                                                  -1).sum(dim=1, keepdim=True)
            dlog_pw = dlog_pw.view(batch_size, 1).to(z)
            log_pz = log_pw - dlog_pw
        else:
            log_pz = torch.zeros(batch_size, 1).to(z)

        # Compute recon. P(X|z)
        z_new = z.view(z.shape) + (log_pz * 0.).mean()  # ? why
        y, dlog_py = self.pointCNF(x, z_new,
                                   torch.zeros(batch_size, num_pts, 1).to(x))
        log_py = standard_normal_logp(y).view(batch_size, -1).sum(dim=1,
                                                                  keepdim=True)
        dlog_py = dlog_py.view(batch_size, num_pts, 1).to(x)
        log_px = log_py - dlog_py

        # Loss
        entropy_loss = -entropy.mean() * self.entropy_w
        recon_loss = -log_px.mean() * self.recon_w
        prior_loss = -log_pz.mean() * self.prior_w
        loss = entropy_loss + recon_loss + prior_loss
        loss.backward()
        opt.step()

        # Write logs
        if self.distributed:
            raise NotImplementedError("Distributed training not implemented!")
        else:
            entropy_log = entropy.mean()
            recon_log = -log_px.mean()
            prior_log = -log_pz.mean()

        recon_nats = recon_log / float(x.size(1) * x.size(2))
        prior_nats = prior_log / float(self.fz)

        # reconstruct to save
        with torch.no_grad():
            recon_pc = self.reconstruct(x, truncate_std=True)
            recon_im = visualize(recon_pc,
                                 path='/home/tmp/screenshot.png',
                                 samples=1)

        # sample to save
        if self.use_latent_flow:
            with torch.no_grad():
                sample_pc = self.sample(1, 1024, gpu=sample_gpu)
                sample_im = visualize(sample_pc,
                                      samples=1,
                                      path='/home/tmp/screenshot.png')

        record_dict = {
            'train/entropy':
            entropy_log.cpu().detach().item()
            if not isinstance(entropy_log, float) else entropy_log,
            'train/prior':
            prior_log,
            'train/recon':
            recon_log,
            'train/recon-nats':
            recon_nats,
            'train/prior-nats':
            prior_nats,
            # 'train/sample-reconstructed': recon_pc
        }

        if summary_writer is not None:
            for key, value in record_dict.items():
                summary_writer.add_scalar(key, value, step)

        record_dict['train/sample-reconstructed'] = recon_im
        if summary_writer is not None:
            summary_writer.add_images('train/sample-reconstructed',
                                      recon_im,
                                      step,
                                      dataformats='NHWC')
        if self.use_latent_flow:
            record_dict['train/sample-sampled'] = sample_im
            if summary_writer is not None:
                summary_writer.add_images('train/sample-sampled',
                                          sample_im,
                                          step,
                                          dataformats='NHWC')
        return record_dict
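The encoder-side helpers reparametrized_gaussian and gaussian_entropy are referenced above but not defined in this example. Minimal sketches, under the assumption that z_sigma holds per-dimension standard deviations (if the encoder actually outputs log-variances, both helpers would need the corresponding exp/scaling):

import math
import torch

def reparametrized_gaussian(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps with eps ~ N(0, I), keeping gradients w.r.t. mu and sigma."""
    return mu + sigma * torch.randn_like(sigma)

def gaussian_entropy(sigma: torch.Tensor) -> torch.Tensor:
    """Per-sample entropy of a diagonal Gaussian: 0.5 * D * log(2*pi*e) + sum_i log(sigma_i)."""
    d = sigma.shape[1]
    return 0.5 * d * (1.0 + math.log(2 * math.pi)) + sigma.log().sum(dim=1)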