Example #1
0
 def __get__(self, instance, obj_type=None):
     """Descriptor protocol hook: lazily compute and cache the wrapped value.

     On first access through an instance, calls ``self.wrapped(instance)``
     with gradient tracking forced on and stores the result on the instance
     under the wrapped function's name, so subsequent lookups hit the plain
     attribute instead of this descriptor.
     """
     # Accessed on the class itself (no instance): return the descriptor.
     if instance is None:
         return self
     # NOTE(review): grad mode is forced on so the cached value keeps its
     # autograd graph -- presumably callers backprop through it; confirm.
     with torch.enable_grad():
         value = self.wrapped(instance)
     # Cache on the instance; the instance attribute shadows this (non-data)
     # descriptor on every later access.
     setattr(instance, self.wrapped.__name__, value)
     return value
def train(args, model, train_dataset, epoch):
    """Run one epoch of professor-forcing-style training.

    Per batch, three losses are summed and backpropagated together:
      1. free-running loss: the model is fed its own previous prediction,
      2. teacher-forcing loss: the model is fed the ground-truth sequence,
      3. a simplified professor-forcing loss pulling the free-running hidden
         states toward the (detached) teacher-forcing hidden states.

    NOTE(review): relies on module-level ``optimizer``, ``criterion`` and
    ``get_batch`` defined elsewhere in this file -- confirm they exist.
    """
    with torch.enable_grad():
        # Turn on training mode which enables dropout.
        model.train()
        total_loss = 0
        start_time = time.time()
        hidden = model.init_hidden(args.batch_size)
        for batch, i in enumerate(range(0, train_dataset.size(0) - 1, args.bptt)):
            inputSeq, targetSeq = get_batch(args, train_dataset, i)
            # inputSeq: [ seq_len * batch_size * feature_size ]
            # targetSeq: [ seq_len * batch_size * feature_size ]

            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            hidden = model.repackage_hidden(hidden)
            hidden_ = model.repackage_hidden(hidden)
            optimizer.zero_grad()

            '''Loss1: Free running loss'''
            outVal = inputSeq[0].unsqueeze(0)
            outVals = []
            hids1 = []
            # BUG FIX: the inner loop previously reused ``i`` as its index,
            # shadowing the batch-offset ``i`` of the outer loop; a distinct
            # name keeps the two indices from being confused.
            for t in range(inputSeq.size(0)):
                outVal, hidden_, hid = model.forward(outVal, hidden_, return_hiddens=True)
                outVals.append(outVal)
                hids1.append(hid)
            outSeq1 = torch.cat(outVals, dim=0)
            hids1 = torch.cat(hids1, dim=0)
            loss1 = criterion(outSeq1.view(args.batch_size, -1), targetSeq.view(args.batch_size, -1))

            '''Loss2: Teacher forcing loss'''
            outSeq2, hidden, hids2 = model.forward(inputSeq, hidden, return_hiddens=True)
            loss2 = criterion(outSeq2.view(args.batch_size, -1), targetSeq.view(args.batch_size, -1))

            '''Loss3: Simplified Professor forcing loss'''
            # hids2 is detached so only the free-running pass is pushed toward
            # the teacher-forcing hidden states, not vice versa.
            loss3 = criterion(hids1.view(args.batch_size, -1), hids2.view(args.batch_size, -1).detach())

            '''Total loss = Loss1+Loss2+Loss3'''
            loss = loss1 + loss2 + loss3
            loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            total_loss += loss.item()

            if batch % args.log_interval == 0 and batch > 0:
                cur_loss = total_loss / args.log_interval
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.4f} | '
                      'loss {:5.2f} '.format(
                    epoch, batch, len(train_dataset) // args.bptt,
                                  elapsed * 1000 / args.log_interval, cur_loss))
                total_loss = 0
                start_time = time.time()
Example #3
0
    def backward(ctx, *args):
        """Recompute the checkpointed forward pass, then backprop through it."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        # Re-run the saved function on detached copies of the saved tensors so
        # autograd records a fresh graph for this segment only.
        detached = detach_variable(ctx.saved_tensors)
        with torch.enable_grad():
            recomputed = ctx.run_function(*detached)

        # Normalize a single-tensor result to a tuple for autograd.backward.
        if isinstance(recomputed, torch.Tensor):
            recomputed = (recomputed,)
        torch.autograd.backward(recomputed, args)
        # Leading None matches the run_function argument, which gets no grad.
        grads = [inp.grad for inp in detached]
        return (None,) + tuple(grads)
Example #4
0
    def _epoch(self, net, loader_train, loader_val, y_val, optimizers,
               loss_criterion, pretext, state, scheduler, epoch, epochs,
               databar_disable, reporter, params):
        """
        Helper function to run one epoch of training, essentially the "inner loop" of training.

        Runs in training mode (gradients enabled, optimizer steps) when
        ``optimizers`` is provided, otherwise in evaluation mode under
        ``torch.no_grad()``.

        Returns:
            Tuple of (mean loss over the epoch, last validation metric or None).
        """
        import torch
        from .utils import augmentation
        is_train = (optimizers is not None)
        net.train() if is_train else net.eval()
        total_loss, total_correct, total_num = 0.0, 0.0, 0
        data_bar = tqdm(loader_train,
                        disable=databar_disable) if is_train else tqdm(
                            loader_val, disable=databar_disable)

        val_metric = None  # stays None when no validation scoring happens
        with (torch.enable_grad() if is_train else torch.no_grad()):
            for data, target in data_bar:
                data, target = pretext.get(data, target)

                if self.device.type == "cuda":
                    data, target = data.cuda(), target.cuda()
                    pretext = pretext.cuda()

                # None/'finetune' use the supervised head (with optional
                # augmentation); 'pretrain' uses the pretext head.
                if state in [None, 'finetune']:
                    if self.params['num_augs'] > 0:
                        data, target = augmentation(data, target, **params)
                    out, _ = net(data)
                elif state == 'pretrain':
                    _, out = net(data)
                else:
                    raise NotImplementedError(
                        "state must be one of [None, 'pretrain', 'finetune']")

                loss, correct = pretext(out, target)

                if is_train:
                    for optimizer in optimizers:
                        optimizer.zero_grad()
                    loss.backward()
                    for optimizer in optimizers:
                        optimizer.step()

                total_num += 1
                total_loss += loss.item()

                # epochs == 1 is used for a single evaluation pass.
                if epochs == 1:
                    train_test = 'Test'
                else:
                    train_test = 'Train'

                if loader_val is not None and state != 'pretrain':
                    val_metric = self.score(X=loader_val,
                                            y=y_val,
                                            metric=self.stopping_metric)
                    data_bar.set_description(
                        '{} Epoch: [{}/{}] Train Loss: {:.4f} Validation {}: {:.2f}'
                        .format(train_test, epoch, epochs,
                                total_loss / total_num,
                                self.stopping_metric.name, val_metric))

                    if reporter is not None:
                        reporter(epoch=epoch + 1,
                                 validation_performance=val_metric,
                                 train_loss=total_loss)

                else:
                    data_bar.set_description(
                        '{} Epoch: [{}/{}] Loss: {:.4f}'.format(
                            train_test, epoch, epochs, total_loss / total_num))

        # BUG FIX: the original returned from inside the ``with`` block, which
        # made the scheduler step (and a second, scalar-valued return)
        # unreachable dead code. Step the scheduler once per epoch here and
        # return the single (loss, val_metric) tuple callers already receive.
        if scheduler is not None:
            scheduler.step()
        return total_loss / total_num, val_metric
Example #5
0
    def step(self, closure=None):
        """Performs a single optimization step (multi-tensor Adam).

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            amsgrad = group['amsgrad']

            grads = []
            states = []
            exp_avg = []
            exp_avg_sq = []
            max_exp_avg_sq = []
            params_with_grad = []

            # Collect parameters with dense gradients; sparse is unsupported.
            for p in group['params']:
                if p.grad is not None:
                    if p.grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    params_with_grad.append(p)
                    grads.append(p.grad)

            for p in params_with_grad:
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg.append(state['exp_avg'])
                exp_avg_sq.append(state['exp_avg_sq'])

                if amsgrad:
                    max_exp_avg_sq.append(state['max_exp_avg_sq'])

                state['step'] += 1
                states.append(state)

            beta1, beta2 = group['betas']

            bias_correction1 = [1 - beta1 ** state['step'] for state in states]
            bias_correction2 = [1 - beta2 ** state['step'] for state in states]
            if group['weight_decay'] != 0:
                # Classic (coupled) L2 regularization folded into the gradient.
                grads = torch._foreach_add(grads, params_with_grad, alpha=group['weight_decay'])

            #
            # Decay the first and second moment running average coefficient
            #
            torch._foreach_mul_(exp_avg, beta1)
            torch._foreach_add_(exp_avg, grads, alpha=1 - beta1)

            torch._foreach_mul_(exp_avg_sq, beta2)
            torch._foreach_addcmul_(exp_avg_sq, grads, grads, value=1 - beta2)

            if amsgrad:
                # BUG FIX: update the max-state tensors IN PLACE so the running
                # maximum persists across steps. The original built brand-new
                # tensors with ``torch.max(a, b)``, leaving
                # ``state['max_exp_avg_sq']`` permanently at zero.
                for max_v, v in zip(max_exp_avg_sq, exp_avg_sq):
                    torch.max(max_v, v, out=max_v)
                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
                bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
                # _foreach_div_ supports a scalar list; the old private
                # _foreach_div_scalar_list_ entry point no longer exists.
                torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction_sqrt)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, group['eps'])
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sq)
                bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
                torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
                denom = torch._foreach_add(exp_avg_sq_sqrt, group['eps'])

            step_size = [group['lr'] / bc for bc in bias_correction1]

            # p <- p - step_size * exp_avg / denom
            for i in range(len(step_size)):
                params_with_grad[i].addcdiv_(exp_avg[i], denom[i], value=-step_size[i])

        return loss
Example #6
0
def main(args):
    """Train a SQuAD question-answering model end to end.

    Builds the model selected by ``args.model`` (BiDAF variants or a
    Transformer), then runs the training loop with EMA-averaged weights,
    periodic dev-set evaluation, TensorBoard logging and checkpointing.
    """
    # Set up logging and devices
    startime = datetime.now()
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)

    # Timing verbosity: 0 = off, higher values log more timing detail below.
    time_log = args.log_time
    if time_log > 0:
        log.info(f'Start training at: {startime.strftime("%H:%M:%S")}')

    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    # Scale the batch size with the number of GPUs used by DataParallel.
    args.batch_size *= max(1, len(args.gpu_ids))
    model_type = args.model

    # Set random seed
    log.info(f'Using random seed {args.seed}...')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # check this
    #useCharEmbeddings = args.model == 'BiDAFplus'

    # Get embeddings
    log.info('Loading embeddings...')
    print(f'{args.word_emb_file}')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)
    if time_log > 0:
        log.info(f'Loaded embeddings: {(datetime.now()-startime).seconds}')
    # load_char_vectors
    # Get model
    log.info('Building model...')
    if model_type == 'BiDAFplus':  #
        model = BiDAFplus(word_vectors=word_vectors,
                          char_vectors=char_vectors,
                          hidden_size=args.hidden_size,
                          params=get_params(model_type, args.params))

    elif model_type == 'BiDAFbase':
        model = BiDAFbase(word_vectors=word_vectors,
                          hidden_size=args.hidden_size,
                          drop_prob=args.drop_prob)

    elif model_type == "Transformer":
        model = TransformerModel(word_vectors=word_vectors,
                                 char_vectors=char_vectors,
                                 params=get_params(model_type, args.params))

    elif model_type == 'BiDAF':
        model = BiDAF(word_vectors=word_vectors,
                      char_vectors=char_vectors,
                      hidden_size=args.hidden_size,
                      params=get_params(model_type, args.params))

    # NOTE(review): no ``else`` branch above -- an unrecognized args.model
    # leaves ``model`` unbound and the next line raises NameError; consider
    # raising ValueError explicitly.
    model = nn.DataParallel(model, args.gpu_ids)
    if time_log > 0:
        log.info(f'Built model: {(datetime.now()-startime).seconds}')
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    # Exponential moving average of the weights; swapped in only for eval.
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    optimizer = optim.Adadelta(model.parameters(),
                               args.lr,
                               weight_decay=args.l2_wd)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR

    # Get data loader
    log.info('Building dataset...')
    if args.mode != 'quick_eval':
        train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       collate_fn=collate_fn)

        dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
        dev_loader = data.DataLoader(dev_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     collate_fn=collate_fn)

    else:
        # quick_eval: replay one cached batch to smoke-test the loop fast.
        loaded_data = quick_eval_data_loader()
        train_loader = [loaded_data for _ in range(5)]
        dev_loader = [quick_eval_data_loader(dev=True)]
        train_dataset = train_loader
        dev_dataset = dev_loader

    log.info('Built dataset: {}:{}'.format(*divmod((datetime.now() -
                                                    startime).seconds, 60)))

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    # ``step`` counts examples seen; resume the epoch count from a checkpoint.
    epoch = step // len(train_dataset)
    if time_log > 0:
        traintime = datetime.now()
    total_iterations = 0
    while epoch != args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        if time_log > 0:
            epochtime = datetime.now()
        if args.mode != 'quick_eval':
            progress_len = len(train_loader.dataset)
        else:
            progress_len = len(train_loader)
        with torch.enable_grad(), \
                tqdm(total=progress_len) as progress_bar:

            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:

                #quick_eval_data_saver(cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids)

                #########
                if time_log > 0:
                    itertime = datetime.now()
                # Setup for forward
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # NOTE(review): 'BiDAFplus' can be constructed above but has
                # no forward branch here, so log_p1/log_p2 would be unbound
                # for that model type -- confirm intended coverage.
                if model_type == 'BiDAF' or model_type == "Transformer":
                    cc_idxs = cc_idxs.to(device)
                    qc_idxs = qc_idxs.to(device)

                    log_p1, log_p2 = model(cc_idxs, qc_idxs, cw_idxs, qw_idxs)

                # Forward
                elif model_type == 'BiDAFbase':
                    log_p1, log_p2 = model(cw_idxs, qw_idxs)

                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()

                if time_log > 2:
                    forwardtime = datetime.now()
                    log.info('Forward time {}:{}'.format(
                        *divmod((forwardtime - itertime).seconds, 60)))
                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                optimizer.step()
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                if time_log > 2:
                    backwardtime = datetime.now()
                    log.info('Backward time {}:{}'.format(
                        *divmod((backwardtime - forwardtime).seconds, 60)))
                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR', optimizer.param_groups[0]['lr'],
                               step)

                if time_log > 0:
                    enditertime = datetime.now()
                    #log.info('Iteration {} {}:{}'.format(total_iterations,
                    #    *divmod((enditertime-itertime).seconds, 60)))

                steps_till_eval -= batch_size
                if steps_till_eval <= 0 or args.mode == 'quick_eval':
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint
                    log.info(f'Evaluating at step {step}...')
                    # Swap in EMA weights for evaluation, restore afterwards.
                    ema.assign(model)
                    results, pred_dict = evaluate(
                        model,
                        dev_loader,
                        device,
                        args.dev_eval_file,
                        args.max_ans_len,
                        args.use_squad_v2,
                        model_type,
                        quick_eval=args.mode == 'quick_eval')
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    if time_log > 1:
                        log.info('Eval time {}:{}'.format(
                            *divmod((datetime.now() -
                                     enditertime).seconds, 60)))

                    results_str = ', '.join(f'{k}: {v:05.2f}'
                                            for k, v in results.items())
                    log.info(f'Dev {results_str}')

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar(f'dev/{k}', v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)
                total_iterations += 1
                if ((time_log == 2) and (total_iterations % 10 == 0)) or (
                    (time_log == 1) and (total_iterations % 100 == 0)):
                    log.info('Mean iteration time {}:{}'.format(
                        *divmod((enditertime - traintime).seconds /
                                total_iterations, 60)))

        if time_log > 0:
            endepochtime = datetime.now()
            log.info('Epoch time {}:{}'.format(
                *divmod((endepochtime - epochtime).seconds, 60)))
    def train(self, net, samples, optimizer, e):
        """Train ``net`` for one epoch on labelled, auxiliary and pseudo-labelled images.

        Builds a weighted mix of the labelled dataset, auxiliary samples and a
        semi-supervised pseudo-labelled test set, runs one pass of 3200
        weighted-sampled images, and returns the averaged training statistics
        with keys prefixed by ``'train_'``.
        """
        # Loss blend: alpha decays linearly from 2 at epoch 0 to 0 at epoch 50+.
        alpha = 2 * max(0, ((50 - e) / 50))
        criterion = losses.ELULovaszFocalWithLogitsLoss(alpha, 2 - alpha)

        # NOTE(review): ``random`` here is a project augmentation module
        # (RandomFlipLr/RandomAffine), not the stdlib random module -- confirm.
        transforms = generator.TransformationsGenerator([
            random.RandomFlipLr(),
            random.RandomAffine(image_size=101,
                                translation=lambda rs:
                                (rs.randint(-20, 20), rs.randint(-20, 20)),
                                scale=lambda rs: (rs.uniform(0.85, 1.15), 1),
                                **utils.transformations_options)
        ])

        # Auxiliary samples: the subset of ``samples`` flagged as aux data.
        samples_aux = list(
            set(samples).intersection(set(utils.get_aux_samples())))
        dataset_aux = datasets.ImageDataset(samples_aux, settings.train,
                                            transforms)

        # NOTE(review): ``samples_test`` is not defined in this scope --
        # presumably a module-level global; verify it exists.
        dataset_pseudo = datasets.SemiSupervisedImageDataset(
            samples_test,
            settings.test,
            transforms,
            size=len(samples_test),
            test_predictions=self.test_predictions,
            momentum=0.0)

        dataset = datasets.ImageDataset(samples, settings.train, transforms)
        # Sampling weights: labelled samples get twice the aux weight, and
        # pseudo-labelled samples get weight 1.
        weight_train = len(dataset_pseudo) / len(dataset) * 2
        weight_aux = weight_train / 2
        weights = [weight_train] * len(dataset) + [weight_aux] * len(
            dataset_aux) + [1] * len(dataset_pseudo)
        dataloader = DataLoader(
            ConcatDataset([dataset, dataset_aux, dataset_pseudo]),
            num_workers=10,
            batch_size=16,
            sampler=WeightedRandomSampler(weights=weights, num_samples=3200))

        average_meter_train = meters.AverageMeter()

        with tqdm(total=len(dataloader), leave=False,
                  ascii=True) as pbar, torch.enable_grad():
            net.train()

            # Pads the 101-px images (13/14 px per side) before the net and
            # crops the prediction back afterwards.
            padding = tta.Pad((13, 14, 13, 14))

            for images, masks_targets in dataloader:
                masks_targets = masks_targets.to(gpu)
                masks_predictions = padding.transform_backward(
                    net(padding.transform_forward(images))).contiguous()

                loss = criterion(masks_predictions, masks_targets)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                average_meter_train.add('loss', loss.item())
                self.update_pbar(torch.sigmoid(masks_predictions),
                                 masks_targets, pbar, average_meter_train,
                                 'Training epoch {}'.format(e))

        # Prefix all averaged metrics with 'train_' for the caller.
        train_stats = {
            'train_' + k: v
            for k, v in average_meter_train.get_all().items()
        }
        return train_stats
Example #8
0
    def fit_generator(self,
                      train_generator,
                      valid_generator=None,
                      *,
                      epochs=1000,
                      steps_per_epoch=None,
                      validation_steps=None,
                      initial_epoch=1,
                      verbose=True,
                      callbacks=None):
        # pylint: disable=too-many-locals, line-too-long
        """
        Trains the model on a dataset using a generator.

        Args:
            train_generator: Generator-like object for the training dataset.
                The generator must yield a tuple ``(x, y)`` where ``x`` is a
                batch of the training dataset and ``y`` is the corresponding
                ground truths. ``y`` should be a Tensor or a Numpy array with
                the first dimension being the batch size since ``len(y)`` is
                taken as the batch size. The loss and the metrics are averaged
                using this batch size. If ``y`` is not a Tensor or a Numpy
                array, then a warning is raised and the "batch size" defaults
                to 1.

                If the generator does not have a method ``__len__()``, either
                the ``steps_per_epoch`` argument must be provided, or the
                iterator returned raises a StopIteration exception at the end
                of the training dataset. PyTorch DataLoaders object do provide a
                ``__len__()`` method.

                Before each epoch, the method ``__iter__()`` on the generator is
                called and the method ``__next__()`` is called for each step on
                resulting object returned by ``__iter__()``. Notice that a call
                to ``__iter__()`` on a generator made using the python keyword
                ``yield`` returns the generator itself.
            valid_generator (optional): Generator-like object for the
                validation dataset. This generator is optional. The generator is
                used the same way as the  generator ``train_generator``. If the
                generator does not have a method ``__len__()``, either the
                ``validation_steps`` or the ``steps_per_epoch`` argument must be
                provided or the iterator returned raises a StopIteration
                exception at the end of the validation dataset.
                (Default value = None)
            epochs (int): Number of times the entire training dataset is seen.
                (Default value = 1000)
            steps_per_epoch (int, optional): Number of batch used during one
                epoch. Obviously, using this argument may cause one epoch not to
                see the entire training dataset or see it multiple times.
                (Defaults the number of steps needed to see the entire
                training dataset)
            validation_steps (int, optional): Same as for ``steps_per_epoch``
                but for the validation dataset. (Defaults to ``steps_per_epoch``
                if provided or the number of steps needed to see the entire
                validation dataset)
            initial_epoch (int, optional): Epoch at which to start training
                (useful for resuming a previous training run).
                (Default value = 1)
            verbose (bool): Whether to display the progress of the training.
                (Default value = True)
            callbacks (list of poutyne.framework.Callback, optional): List of
                callbacks that will be called during training.
                (Default value = None, treated as an empty list)

        Returns:
            List of dict containing the history of each epoch.

        Example:
            .. code-block:: python

                model = Model(pytorch_module, optimizer, loss_function)
                history = model.fit_generator(train_generator,
                                              valid_generator,
                                              epochs=num_epochs,
                                              verbose=False)
                print(*history, sep="\\n")

            .. code-block:: python

                {'epoch': 1, 'loss': 1.7198852968215943, 'time': 0.019999928001197986, 'acc': 19.375, 'val_loss': 1.6674459838867188, 'val_acc': 22.0}
                {'epoch': 2, 'loss': 1.7054892110824584, 'time': 0.015421080999658443, 'acc': 19.75, 'val_loss': 1.660806336402893, 'val_acc': 22.0}
                {'epoch': 3, 'loss': 1.6923445892333984, 'time': 0.01363091799794347, 'acc': 19.625, 'val_loss': 1.6550078630447387, 'val_acc': 22.5}
                ...

        """
        self._transfer_optimizer_state_to_right_device()

        # BUG FIX: avoid a mutable default argument; ``None`` now stands for
        # "no extra callbacks" and is normalized here. Passing a list behaves
        # exactly as before.
        callbacks = [] if callbacks is None else callbacks
        if verbose:
            callbacks = [ProgressionCallback()] + callbacks
        callback_list = CallbackList(callbacks)
        callback_list.set_model(self)

        self.stop_training = False
        epoch_iterator = EpochIterator(train_generator,
                                       valid_generator,
                                       epochs=epochs,
                                       steps_per_epoch=steps_per_epoch,
                                       validation_steps=validation_steps,
                                       initial_epoch=initial_epoch,
                                       callback=callback_list,
                                       metrics_names=self.metrics_names)

        # One pass per epoch: train under grad mode, then optionally validate.
        for train_step_iterator, valid_step_iterator in epoch_iterator:
            self.model.train(True)
            with torch.enable_grad():
                for step, (x, y) in train_step_iterator:
                    step.loss, step.metrics, _ = self._fit_batch(
                        x, y, callback=callback_list, step=step.number)
                    step.size = self._get_batch_size(x, y)

            if valid_step_iterator is not None:
                self._validate(valid_step_iterator)

            epoch_iterator.stop_training = self.stop_training

        return epoch_iterator.epoch_logs
Example #9
0
def main(args):
    """Train a SQuAD question-answering model end to end.

    Sets up logging/TensorBoard, seeds every RNG, builds the model
    (``BertGMV`` wrapped in ``nn.DataParallel``), then runs the training
    loop with periodic EMA-weighted evaluation and checkpointing.

    Args:
        args: parsed command-line namespace. Expected attributes include
            ``save_dir``, ``name``, ``seed``, ``word_emb_file``, ``load_path``,
            ``lr``, ``l2_wd``, ``batch_size``, ``num_epochs``, ``eval_steps``,
            ``ema_decay``, ``metric_name`` and friends — see the argument
            parser (defined elsewhere) for the full list.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    # Scale the global batch size by the number of GPUs: DataParallel splits
    # each batch across devices, so the per-GPU batch size stays constant.
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info('Using random seed {}...'.format(args.seed))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get embeddings
    # NOTE(review): word_vectors is loaded but never passed to BertGMV below —
    # confirm whether this load is still needed.
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)

    # Get model
    log.info('Building model...')
    '''
    model = BiDAF(word_vectors=word_vectors,
                  hidden_size=args.hidden_size,
                  drop_prob=args.drop_prob)
    '''
    model = BertGMV(device)
    #model = Squad2Model.from_pretrained("bert-base-uncased",
    #            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE)))

    model = nn.DataParallel(model, args.gpu_ids)

    if args.load_path:
        log.info('Loading checkpoint from {}...'.format(args.load_path))
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    # Exponential moving average of the weights; applied only at eval time
    # via ema.assign()/ema.resume() below.
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    optimizer = optim.Adadelta(model.parameters(), args.lr,
                               weight_decay=args.l2_wd)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR

    # Get data loader
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    # ``step`` counts examples seen; recover the epoch index when resuming.
    epoch = step // len(train_dataset)
    # NOTE(review): if a checkpoint resumes with epoch > num_epochs this
    # inequality never triggers and the loop will not terminate — confirm
    # whether `epoch < args.num_epochs` was intended.
    while epoch != args.num_epochs:
        epoch += 1
        log.info('Starting epoch {}...'.format(epoch))
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:
                # Setup for forward
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward: model returns log-probabilities of answer start
                # and end positions.
                log_p1, log_p2 = model(cw_idxs, qw_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                # Scheduler/EMA are indexed by optimizer steps, hence the
                # example-count divided by batch size.
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch,
                                         NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint using the EMA weights,
                    # then restore the raw training weights.
                    log.info('Evaluating at step {}...'.format(step))
                    ema.assign(model)
                    results, pred_dict = evaluate(model, dev_loader, device,
                                                  args.dev_eval_file,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                            for k, v in results.items())
                    log.info('Dev {}'.format(results_str))

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar('dev/{}'.format(k), v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)
Exemple #10
0
    def step(self, closure=None):
        """Performs a single (multi-tensor) Rprop optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # BUGFIX: these accumulators must be per-group. Previously they
            # were created once outside the loop, so with multiple param
            # groups the vectorized update below re-applied earlier groups'
            # updates on every subsequent group.
            grads = []
            states = []
            params_with_grad = []
            step_sizes = []

            # Group-level hyper-parameters (identical for every param in it).
            etaminus, etaplus = group['etas']
            step_size_min, step_size_max = group['step_sizes']

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    # BUGFIX: message previously said "RMSprop".
                    raise RuntimeError(
                        'Rprop does not support sparse gradients')

                grads.append(p.grad)
                params_with_grad.append(p)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['prev'] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)
                    state['step_size'] = p.grad.new().resize_as_(
                        p.grad).fill_(group['lr'])

                # BUGFIX: the step counter must advance on *every* call;
                # it was previously indented inside the init block above
                # and therefore only ever incremented once.
                state['step'] += 1

                states.append(state)
                step_sizes.append(state['step_size'])

            if not grads:
                # Nothing to update in this group; also avoids calling the
                # _foreach_ ops with empty lists.
                continue

            # sign(grad * prev_grad) decides whether each step size grows
            # (agreement -> etaplus), shrinks (sign flip -> etaminus) or
            # stays put (zero -> factor 1).
            signs = torch._foreach_mul(grads, [s['prev'] for s in states])
            signs = [s.sign() for s in signs]
            for sign in signs:
                sign[sign.gt(0)] = etaplus
                sign[sign.lt(0)] = etaminus
                sign[sign.eq(0)] = 1

            # update stepsizes with step size updates
            torch._foreach_mul_(step_sizes, signs)
            for step_size in step_sizes:
                step_size.clamp_(step_size_min, step_size_max)

            # for dir<0, dfdx=0
            # for dir>=0 dfdx=dfdx
            for i in range(len(grads)):
                grads[i] = grads[i].clone(memory_format=torch.preserve_format)
                grads[i][signs[i].eq(etaminus)] = 0

            # update parameters: p -= step_size * sign(grad)
            grad_signs = [grad.sign() for grad in grads]
            torch._foreach_addcmul_(params_with_grad,
                                    grad_signs,
                                    step_sizes,
                                    value=-1)

            for i in range(len(states)):
                states[i]['prev'].copy_(grads[i])

        return loss
Exemple #11
0
    def train_step(self, use_gpu, local_variables=None):
        """Train step to be executed in train loop

        Fetches the next sample, runs forward pass + loss, updates meters,
        then runs the backward pass and optimizer update (via apex AMP when
        ``self.amp_opt_level`` is set). Intermediate values are stored in
        ``local_variables`` so that hooks can observe them; the dict keys
        ("sample", "target", "output", "local_loss", "loss") are part of
        that hook-facing contract.

        Args:
            use_gpu: if true, execute training on GPU
            local_variables: Dict containing intermediate values
                in train_step for access by hooks
        """

        if local_variables is None:
            local_variables = {}

        # Process next sample
        sample = next(self.get_data_iterator())
        local_variables["sample"] = sample

        assert (
            isinstance(local_variables["sample"], dict)
            and "input" in local_variables["sample"]
            and "target" in local_variables["sample"]), (
                f"Returned sample [{sample}] is not a map with 'input' and" +
                "'target' keys")

        # Copy sample to GPU
        # NOTE(review): "target" is captured *before* the GPU copy, so it
        # stays on CPU while sample["target"] moves — it is only used below
        # for .size(0), but confirm that is intentional.
        local_variables["target"] = local_variables["sample"]["target"]
        if use_gpu:
            for key, value in local_variables["sample"].items():
                local_variables["sample"][key] = recursive_copy_to_gpu(
                    value, non_blocking=True)

        with torch.enable_grad():
            # Forward pass
            local_variables["output"] = self.model(
                local_variables["sample"]["input"])

            local_variables["local_loss"] = self.compute_loss(
                local_variables["output"], local_variables["sample"])

            # "loss" is a detached, cross-replica-averaged copy used only for
            # reporting; "local_loss" keeps the graph for backprop.
            local_variables["loss"] = local_variables["local_loss"].detach(
            ).clone()
            local_variables["loss"] = all_reduce_mean(local_variables["loss"])

            # Weight the reported loss by batch size (target.size(0)).
            self.losses.append(local_variables["loss"].data.cpu().item() *
                               local_variables["target"].size(0))

            self.update_meters(local_variables["output"],
                               local_variables["sample"])

        # Run backwards pass / update optimizer
        if self.amp_opt_level is not None:
            # apex AMP path: scale the loss to avoid fp16 underflow.
            self.optimizer.zero_grad()
            with apex.amp.scale_loss(local_variables["local_loss"],
                                     self.optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            self.optimizer.backward(local_variables["local_loss"])

        self.optimizer.update_schedule_on_step(self.where)
        self.optimizer.step()

        self.num_updates += self.get_global_batchsize()
Exemple #12
0
    def trades_loss(self,
                    model,
                    x_natural,
                    y,
                    optimizer,
                    step_size=0.003,
                    epsilon=0.031,
                    perturb_steps=10,
                    beta=1.0,
                    distance='l_inf'):
        """TRADES robust training loss (natural CE + beta * adversarial KL).

        Crafts an adversarial example ``x_adv`` near ``x_natural`` by
        maximizing the KL divergence between the model's predictions on the
        perturbed and clean inputs, then returns
        ``CE(model(x_natural), y) + beta * KL(model(x_adv) || model(x_natural)) / batch``.

        Args:
            model: classifier returning logits.
            x_natural: clean input batch (values assumed in [0, 1]).
            y: ground-truth labels.
            optimizer: the training optimizer; only ``zero_grad()`` is called
                here so stale attack gradients don't leak into the real step.
            step_size: per-iteration attack step size.
            epsilon: perturbation budget (l_inf radius or l_2 norm bound).
            perturb_steps: number of attack iterations.
            beta: weight of the robust (KL) term.
            distance: 'l_inf' (PGD-style) or 'l_2' (SGD on delta); any other
                value skips the attack and just clamps the noisy input.

        Returns:
            Scalar loss tensor ``loss_natural + beta * loss_robust``.
        """
        # KL divergence summed over the batch.
        # (Replaces the deprecated ``size_average=False`` argument.)
        criterion_kl = nn.KLDivLoss(reduction='sum')
        # Freeze batch-norm/dropout statistics while crafting x_adv.
        model.eval()
        batch_size = len(x_natural)
        # Start from the natural input plus small Gaussian noise.
        # ``randn_like`` keeps the input's device/dtype instead of the
        # previous hard-coded ``.cuda()``, so this also works on CPU.
        x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)

        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                           F.softmax(model(x_natural), dim=1))
                grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                # Signed-gradient ascent step, then project back into the
                # epsilon ball and the valid pixel range.
                x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)

        elif distance == 'l_2':
            # Optimize the perturbation delta directly with SGD.
            # (``Variable`` is deprecated; a leaf tensor with requires_grad
            # is equivalent.)
            delta = (0.001 * torch.randn_like(x_natural)).detach()
            delta.requires_grad_(True)

            # Setup optimizers
            optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

            for _ in range(perturb_steps):
                adv = x_natural + delta

                # optimize (negated KL: SGD minimizes, we want to maximize KL)
                optimizer_delta.zero_grad()
                with torch.enable_grad():
                    loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1),
                                               F.softmax(model(x_natural), dim=1))
                loss.backward()
                # Renorm gradient per sample.
                # NOTE(review): the (-1, 1, 1, 1) view assumes 4-D image
                # input — confirm for other input ranks.
                grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
                delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
                # avoid nan or inf if gradient is 0
                if (grad_norms == 0).any():
                    delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
                optimizer_delta.step()

                # Projection: clamp x+delta to [0,1], then bound ||delta||_2.
                delta.data.add_(x_natural)
                delta.data.clamp_(0, 1).sub_(x_natural)
                delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
            x_adv = (x_natural + delta).detach()
        else:
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
        model.train()

        x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
        # zero gradient accumulated by the attack before the real loss
        optimizer.zero_grad()
        # calculate robust loss
        logits = model(x_natural)
        loss_natural = F.cross_entropy(logits, y)
        loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                        F.softmax(model(x_natural), dim=1))
        loss = loss_natural + beta * loss_robust
        return loss
Exemple #13
0
    def step(self, closure=None):
        """Performs a single RMSprop optimization step.

        Maintains a running average of squared gradients per parameter
        (optionally centered by a running gradient mean) and divides the
        update by its square root. Supports weight decay and momentum.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss from ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Hoist group hyper-parameters; they are shared by every param.
            lr = group['lr']
            alpha = group['alpha']
            eps = group['eps']
            momentum = group['momentum']
            centered = group['centered']
            weight_decay = group['weight_decay']

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        'RMSprop does not support sparse gradients')

                param_state = self.state[param]

                # Lazily create the per-parameter state on first use.
                if not param_state:
                    param_state['step'] = 0
                    param_state['square_avg'] = torch.zeros_like(
                        param, memory_format=torch.preserve_format)
                    if momentum > 0:
                        param_state['momentum_buffer'] = torch.zeros_like(
                            param, memory_format=torch.preserve_format)
                    if centered:
                        param_state['grad_avg'] = torch.zeros_like(
                            param, memory_format=torch.preserve_format)

                param_state['step'] += 1

                if weight_decay != 0:
                    grad = grad.add(param, alpha=weight_decay)

                # Exponential moving average of squared gradients.
                square_avg = param_state['square_avg']
                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                if centered:
                    # Centered variant subtracts the squared running mean,
                    # estimating the gradient variance instead of the raw
                    # second moment.
                    grad_avg = param_state['grad_avg']
                    grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
                    denom = square_avg.addcmul(
                        grad_avg, grad_avg, value=-1).sqrt_().add_(eps)
                else:
                    denom = square_avg.sqrt().add_(eps)

                if momentum > 0:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).addcdiv_(grad, denom)
                    # Need to avoid version tracking for parameter.
                    param.data.add_(buf, alpha=-lr)
                else:
                    # Need to avoid version tracking for parameter.
                    param.data.addcdiv_(grad, denom, value=-lr)

        return loss
Exemple #14
0
    def train(self, training_batch) -> None:
        """Run one training step of a parametric Q-learning trainer.

        Computes TD targets from detached/target-network Q-values (max-Q or
        SARSA depending on ``self.maxq_learning``), takes a gradient step on
        the Q-network, soft-updates the target network, and fits the reward
        network as an auxiliary task.

        Args:
            training_batch: either a ``TrainingDataPage`` (converted here to
                the parametric max-Q or SARSA batch format) or an
                already-converted training batch.
        """
        if isinstance(training_batch, TrainingDataPage):
            if self.maxq_learning:
                training_batch = training_batch.as_parametric_maxq_training_batch(
                )
            else:
                training_batch = training_batch.as_parametric_sarsa_training_batch(
                )

        learning_input = training_batch.training_input
        self.minibatch += 1

        reward = learning_input.reward
        not_done_mask = learning_input.not_terminal

        # Per-transition discount; optionally gamma^t where t is the
        # sequence-number difference or the multi-step horizon.
        discount_tensor = torch.full_like(reward, self.gamma)
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma,
                                        learning_input.time_diff.float())
        if self.multi_steps is not None:
            discount_tensor = torch.pow(self.gamma,
                                        learning_input.step.float())

        if self.maxq_learning:
            all_next_q_values, all_next_q_values_target = self.get_detached_q_values(
                learning_input.tiled_next_state,
                learning_input.possible_next_actions)
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values.q_value,
                all_next_q_values_target.q_value,
                learning_input.possible_next_actions_mask.float(),
            )
        else:
            # SARSA (Use the target network)
            _, next_q_values = self.get_detached_q_values(
                learning_input.next_state, learning_input.next_action)
            next_q_values = next_q_values.q_value

        # Zero out bootstrapped values for terminal transitions.
        filtered_max_q_vals = next_q_values * not_done_mask.float()

        # Bellman target: r + gamma * max_a' Q_target(s', a').
        target_q_values = reward + (discount_tensor * filtered_max_q_vals)

        with torch.enable_grad():
            # Get Q-value of action taken
            current_state_action = rlt.StateAction(
                state=learning_input.state, action=learning_input.action)
            q_values = self.q_network(current_state_action).q_value
            self.all_action_scores = q_values.detach()

            value_loss = self.q_network_loss(q_values, target_q_values)
            self.loss = value_loss.detach()

            value_loss.backward()
            # NOTE(review): presumably steps the optimizer only every
            # ``minibatches_per_step`` minibatches (gradient accumulation) —
            # confirm against the helper's definition.
            self._maybe_run_optimizer(self.q_network_optimizer,
                                      self.minibatches_per_step)

        # Use the soft update rule to update target network
        self._maybe_soft_update(self.q_network, self.q_network_target,
                                self.tau, self.minibatches_per_step)

        with torch.enable_grad():
            # get reward estimates (auxiliary reward-model loss on the same
            # state/action pair)
            reward_estimates = self.reward_network(
                current_state_action).q_value
            reward_loss = F.mse_loss(reward_estimates, reward)
            reward_loss.backward()
            self._maybe_run_optimizer(self.reward_network_optimizer,
                                      self.minibatches_per_step)

        self.loss_reporter.report(
            td_loss=self.loss,
            reward_loss=reward_loss,
            model_values_on_logged_actions=self.all_action_scores,
        )
Exemple #15
0
    def backward_pass(self, y, n, dy, dn, mask=None, msa_mask=None, **kwargs):
        """Manual backward pass for a reversible block.

        Reconstructs the block's inputs from its outputs (reversible-residual
        trick) while accumulating gradients, so forward activations never
        need to be stored. Each half-step re-runs one sub-module
        (``self.f``/``self.g``/``self.j``/``self.k``) under ``enable_grad``,
        backprops the corresponding upstream gradient into it, then inverts
        the residual under ``no_grad``. The statement order, the explicit
        ``del``s, and the ``.grad = None`` resets are load-bearing: they keep
        peak memory low and ensure each gradient buffer is consumed exactly
        once.

        Args:
            y, n: forward outputs; each is split into two halves on dim 2.
            dy, dn: upstream gradients w.r.t. ``y`` and ``n``.
            mask, msa_mask: attention masks forwarded to the cross sub-modules.
            **kwargs: unused here; accepted for interface compatibility.

        Returns:
            Tuple ``(x, m, dx, dm)``: the reconstructed inputs and their
            gradients.
        """
        # Split every tensor into its two reversible halves and free the
        # originals immediately.
        n1, n2 = torch.chunk(n, 2, dim=2)
        del n

        dn1, dn2 = torch.chunk(dn, 2, dim=2)
        del dn

        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        # Invert n2 = m2 + k(n1): backprop dn2 through k, then recover m2.
        # NOTE(review): set_rng=True presumably replays the forward's RNG
        # state so dropout matches — confirm in the sub-module wrapper.
        with torch.enable_grad():
            n1.requires_grad = True
            gn1 = self.k(n1, set_rng=True)
            torch.autograd.backward(gn1, dn2)

        with torch.no_grad():
            m2 = n2 - gn1
            del n2, gn1

            dm1 = dn1 + n1.grad
            del dn1
            n1.grad = None

        # Invert n1 = m1 + j(m2, y2): backprop dm1 through j.
        with torch.enable_grad():
            m2.requires_grad = True
            y2.requires_grad = True
            fm2 = self.j(m2,
                         y2,
                         set_rng=True,
                         mask=msa_mask,
                         context_mask=mask)
            torch.autograd.backward(fm2, dm1)

        with torch.no_grad():
            m1 = n1 - fm2
            del n1, fm2

            dm2 = dn2 + m2.grad
            dx2 = dy2 + y2.grad
            del dn2
            del dy2
            m2.grad = None
            y2.grad = None

        # Invert y2 = x2 + g(y1): backprop dx2 through g.
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True)
            torch.autograd.backward(gy1, dx2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # Invert y1 = x1 + f(x2, m2): backprop dx1 through f.
        with torch.enable_grad():
            x2.requires_grad = True
            m2.requires_grad = True
            fx2 = self.f(x2,
                         m2,
                         set_rng=True,
                         mask=mask,
                         context_mask=msa_mask)
            torch.autograd.backward(fx2, dx1)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            # Accumulate the remaining gradient contributions.
            dx2 = dx2 + x2.grad
            dm2 = dm2 + m2.grad
            x2.grad = None
            m2.grad = None

        # Reassemble the full tensors from their halves.
        with torch.no_grad():
            m = torch.cat([m1, m2.detach()], dim=2)
            dm = torch.cat([dm1, dm2], dim=2)

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, m, dx, dm
    def train(self, args):
        """Train the speech-enhancement network described by ``args``.

        Builds the training/validation ``AudioLoader``s, the ``Net`` model,
        loss and Adam optimizer (with step LR decay), optionally resumes from
        a checkpoint, then runs the epoch loop with periodic validation and
        best/latest checkpointing.

        Args:
            args: parsed argument namespace with file lists, checkpoint
                directory, learning-rate schedule, batch/segment sizes, etc.

        NOTE(review): this method reads ``self.sample_rate``, ``self.in_norm``,
        ``self.win_size`` and ``self.hop_size`` which are not assigned here —
        presumably set in ``__init__``; confirm.
        """
        with open(args.tr_list, 'r') as f:
            self.tr_list = [line.strip() for line in f.readlines()]
        self.tr_size = len(self.tr_list)
        self.cv_file = args.cv_file
        self.ckpt_dir = args.ckpt_dir
        self.logging_period = args.logging_period
        self.resume_model = args.resume_model
        self.time_log = args.time_log
        self.lr = args.lr
        self.lr_decay_factor = args.lr_decay_factor
        self.lr_decay_period = args.lr_decay_period
        self.clip_norm = args.clip_norm
        self.max_n_epochs = args.max_n_epochs
        self.batch_size = args.batch_size
        self.buffer_size = args.buffer_size
        self.loss_log = args.loss_log
        self.unit = args.unit
        self.segment_size = args.segment_size
        self.segment_shift = args.segment_shift

        # gpu_ids == "-1" selects CPU; otherwise the first id is the primary
        # device and the rest are used for DataParallel.
        self.gpu_ids = tuple(map(int, args.gpu_ids.split(',')))
        if len(self.gpu_ids) == 1 and self.gpu_ids[0] == -1:
            # cpu only
            self.device = torch.device('cpu')
        else:
            # gpu
            self.device = torch.device('cuda:{}'.format(self.gpu_ids[0]))

        if not os.path.isdir(self.ckpt_dir):
            os.makedirs(self.ckpt_dir)

        logger = getLogger(os.path.join(self.ckpt_dir, 'train.log'), log_file=True)

        # create data loaders for training and cross validation
        tr_loader = AudioLoader(self.tr_list, self.sample_rate, self.unit,
                                self.segment_size, self.segment_shift,
                                self.batch_size, self.buffer_size,
                                self.in_norm, mode='train')
        cv_loader = AudioLoader(self.cv_file, self.sample_rate, unit='utt',
                                segment_size=None, segment_shift=None,
                                batch_size=1, buffer_size=10,
                                in_norm=self.in_norm, mode='eval')

        # create a network
        net = Net()
        logger.info('Model summary:\n{}'.format(net))

        net = net.to(self.device)
        if len(self.gpu_ids) > 1:
            net = DataParallel(net, device_ids=self.gpu_ids)

        # calculate model size (assumes 32-bit parameters for the MB figure)
        param_count = numParams(net)
        logger.info('Trainable parameter count: {:,d} -> {:.2f} MB\n'.format(param_count, param_count*32/8/(2**20)))

        # net feeder
        feeder = NetFeeder(self.device, self.win_size, self.hop_size)

        # training criterion and optimizer
        criterion = LossFunction()
        optimizer = Adam(net.parameters(), lr=self.lr, amsgrad=True)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=self.lr_decay_period, gamma=self.lr_decay_factor)

        # resume model if needed
        if self.resume_model:
            logger.info('Resuming model from {}'.format(self.resume_model))
            ckpt = CheckPoint()
            ckpt.load(self.resume_model, self.device)
            # Re-key the state dict with the 'module.' prefix when the live
            # network is wrapped in DataParallel.
            state_dict = {}
            for key in ckpt.net_state_dict:
                if len(self.gpu_ids) > 1:
                    state_dict['module.'+key] = ckpt.net_state_dict[key]
                else:
                    state_dict[key] = ckpt.net_state_dict[key]
            net.load_state_dict(state_dict)
            optimizer.load_state_dict(ckpt.optim_state_dict)
            ckpt_info = ckpt.ckpt_info
            logger.info('model info: epoch {}, iter {}, cv_loss - {:.4f}\n'.format(ckpt.ckpt_info['cur_epoch']+1,
                ckpt.ckpt_info['cur_iter']+1, ckpt.ckpt_info['cv_loss']))
        else:
            logger.info('Training from scratch...\n')
            ckpt_info = {'cur_epoch': 0,
                         'cur_iter': 0,
                         'tr_loss': None,
                         'cv_loss': None,
                         'best_loss': float('inf')}

        start_iter = 0
        # train model
        while ckpt_info['cur_epoch'] < self.max_n_epochs:
            accu_tr_loss = 0.
            accu_n_frames = 0
            net.train()
            for n_iter, egs in enumerate(tr_loader):
                n_iter += start_iter
                mix = egs['mix']
                sph = egs['sph']
                n_samples = egs['n_samples']

                mix = mix.to(self.device)
                sph = sph.to(self.device)
                n_samples = n_samples.to(self.device)

                # frame count per utterance, used to mask padding in the loss
                n_frames = countFrames(n_samples, self.win_size, self.hop_size)

                start_time = timeit.default_timer()

                # prepare features and labels
                feat, lbl = feeder(mix, sph)
                loss_mask = lossMask(shape=lbl.shape, n_frames=n_frames, device=self.device)
                # forward + backward + optimize
                optimizer.zero_grad()
                with torch.enable_grad():
                    est = net(feat)
                loss = criterion(est, lbl, loss_mask, n_frames)
                loss.backward()
                if self.clip_norm >= 0.0:
                    clip_grad_norm_(net.parameters(), self.clip_norm)
                optimizer.step()
                # calculate loss (frame-weighted running average)
                running_loss = loss.data.item()
                accu_tr_loss += running_loss * sum(n_frames)
                accu_n_frames += sum(n_frames)

                end_time = timeit.default_timer()
                batch_time = end_time - start_time

                if self.time_log:
                    with open(self.time_log, 'a+') as f:
                        print('Epoch [{}/{}], Iter [{}], tr_loss = {:.4f} / {:.4f}, batch_time (s) = {:.4f}'.format(ckpt_info['cur_epoch']+1,
                            self.max_n_epochs, n_iter, running_loss, accu_tr_loss / accu_n_frames, batch_time), file=f)
                        f.flush()
                else:
                    print('Epoch [{}/{}], Iter [{}], tr_loss = {:.4f} / {:.4f}, batch_time (s) = {:.4f}'.format(ckpt_info['cur_epoch']+1,
                        self.max_n_epochs, n_iter, running_loss, accu_tr_loss / accu_n_frames, batch_time), flush=True)

                # Periodically validate, checkpoint (latest + best), log, and
                # reset the running-loss accumulators.
                if (n_iter + 1) % self.logging_period == 0:
                    avg_tr_loss = accu_tr_loss / accu_n_frames
                    avg_cv_loss = self.validate(net, cv_loader, criterion, feeder)
                    net.train()

                    ckpt_info['cur_iter'] = n_iter
                    is_best = True if avg_cv_loss < ckpt_info['best_loss'] else False
                    ckpt_info['best_loss'] = avg_cv_loss if is_best else ckpt_info['best_loss']
                    latest_model = 'latest.pt'
                    best_model = 'best.pt'
                    ckpt_info['tr_loss'] = avg_tr_loss
                    ckpt_info['cv_loss'] = avg_cv_loss
                    # Unwrap DataParallel before saving so checkpoints are
                    # loadable without the 'module.' prefix.
                    if len(self.gpu_ids) > 1:
                        ckpt = CheckPoint(ckpt_info, net.module.state_dict(), optimizer.state_dict())
                    else:
                        ckpt = CheckPoint(ckpt_info, net.state_dict(), optimizer.state_dict())
                    logger.info('Saving checkpoint into {}'.format(os.path.join(self.ckpt_dir, latest_model)))
                    if is_best:
                        logger.info('Saving checkpoint into {}'.format(os.path.join(self.ckpt_dir, best_model)))
                    logger.info('Epoch [{}/{}], ( tr_loss: {:.4f} | cv_loss: {:.4f} )\n'.format(ckpt_info['cur_epoch']+1,
                        self.max_n_epochs, avg_tr_loss, avg_cv_loss))

                    model_path = os.path.join(self.ckpt_dir, 'models')
                    if not os.path.isdir(model_path):
                        os.makedirs(model_path)

                    ckpt.save(os.path.join(model_path, latest_model),
                              is_best,
                              os.path.join(model_path, best_model))

                    lossLog(os.path.join(self.ckpt_dir, self.loss_log), ckpt, self.logging_period)

                    accu_tr_loss = 0.
                    accu_n_frames = 0

                    # End the epoch once a full pass over the training list
                    # has been consumed.
                    if n_iter + 1 == self.tr_size // self.batch_size:
                        start_iter = 0
                        ckpt_info['cur_iter'] = 0
                        break

            ckpt_info['cur_epoch'] += 1
            scheduler.step() # learning rate decay

        return
Exemple #17
0
    def step(self, closure=None):
        """Performs a single Adam step using the DeepSpeed CPU kernel.

        Moves each parameter (and its gradient) to CPU/float32, calls the
        native ``ds_opt_adam.adam_update`` kernel against the CPU-resident
        exp_avg/exp_avg_sq state, then copies the updated values back to the
        original device. When ``self.use_fp16_stats`` is set, the optimizer
        state is kept in fp16 with per-tensor scale factors.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss from ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Ensure all GPU work producing the gradients has finished before we
        # read them on the CPU.
        torch.cuda.synchronize()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                state = self.state[p]
                # Lazy state initialization; stats live on CPU by design.
                if len(state) == 0:
                    state["step"] = 0
                    dtype = torch.float16 if self.use_fp16_stats else p.data.dtype
                    # gradient momentums
                    state["exp_avg"] = torch.zeros_like(p.data,
                                                        dtype=dtype,
                                                        device="cpu")
                    # gradient variances
                    state["exp_avg_sq"] = torch.zeros_like(p.data,
                                                           dtype=dtype,
                                                           device="cpu")
                    if self.use_fp16_stats:
                        assert torch.is_floating_point(p.data)
                        # Scale factors so fp16 stats can represent values
                        # outside the fp16 range.
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                p_data_bak = p.data  # backup of the original data pointer

                # The native kernel operates on CPU float32 buffers.
                p.data = p.data.to(dtype=torch.float32, device="cpu")
                p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu")

                if self.use_fp16_stats:
                    # Rehydrate fp32 stats from the scaled fp16 storage.
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]

                state["step"] += 1
                beta1, beta2 = group["betas"]

                # Native fused Adam update (in-place on p.data / stats).
                self.ds_opt_adam.adam_update(
                    self.opt_id,
                    state["step"],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group["weight_decay"],
                    group["bias_correction"],
                    p.data,
                    p.grad.data,
                    exp_avg,
                    exp_avg_sq,
                )

                # If the update happened in a temporary CPU buffer, copy the
                # result back into the original (possibly GPU) tensor.
                if p_data_bak.data_ptr() != p.data.data_ptr():
                    p_data_bak.copy_(p.data)
                    p.data = p_data_bak

                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py
                    # Rescale so the largest stat value fits in fp16.
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss
Exemple #18
0
    def step(self, closure=None, fp16_param_groups=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.
            fp16_param_groups: FP16 GPU parameters to update. Performing the
                copy here reduces communication time. Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """

        loss = None
        if closure is not None:
            # The closure must run with autograd enabled even if step() is
            # called under no_grad().
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                # Skip parameters that received no gradient this step.
                if p.grad is None:
                    continue

                state = self.state[p]
                # State initialization (lazy: performed on the first step
                # that sees a gradient for this parameter).
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0
                    # gradient momentums (first moment), kept on CPU for offload
                    state['exp_avg'] = torch.zeros_like(p.data,
                                                        dtype=p.dtype,
                                                        device='cpu')
                    #memory_format=torch.preserve_format)
                    # gradient variances (second moment), kept on CPU for offload
                    state['exp_avg_sq'] = torch.zeros_like(p.data,
                                                           dtype=p.dtype,
                                                           device='cpu')
                    #memory_format=torch.preserve_format)

                state['step'] += 1
                beta1, beta2 = group['betas']

                if fp16_param_groups is not None:
                    # Fused C++ Adam update that also copies the updated
                    # master weights into the matching fp16 GPU parameter,
                    # saving a separate device transfer.
                    self.ds_opt_adam.adam_update_copy(
                        self.opt_id,
                        state['step'],
                        group['lr'],
                        beta1,
                        beta2,
                        group['eps'],
                        group['weight_decay'],
                        group['bias_correction'],
                        p.data,
                        p.grad.data,
                        state['exp_avg'],
                        state['exp_avg_sq'],
                        fp16_param_groups[group_id][param_id].data)
                else:
                    # Plain fused C++ Adam update, in place on p.data.
                    self.ds_opt_adam.adam_update(self.opt_id,
                                                 state['step'],
                                                 group['lr'],
                                                 beta1,
                                                 beta2,
                                                 group['eps'],
                                                 group['weight_decay'],
                                                 group['bias_correction'],
                                                 p.data,
                                                 p.grad.data,
                                                 state['exp_avg'],
                                                 state['exp_avg_sq'])
        return loss
    def train_step(self):
        """Run one iteration of the training loop: fetch a batch, forward,
        backward, optimizer step, and record the batch for hooks."""

        self.last_batch = None

        # Fetch the next batch, timing how long the fetch takes.
        with Timer() as fetch_timer:
            sample = next(self.data_iterator)

        assert isinstance(sample, dict) and "input" in sample and "target" in sample, (
            f"Returned sample [{sample}] is not a map with 'input' and"
            + "'target' keys"
        )

        # Remember the (pre-transfer) target, then move the sample to GPU.
        target = sample["target"]
        if self.use_gpu:
            sample = recursive_copy_to_gpu(sample, non_blocking=True)

        if self.mixup_transform is not None:
            sample = self.mixup_transform(sample)

        # Optional PyTorch AMP autocast context.
        if self.amp_type == AmpType.PYTORCH:
            torch_amp_context = torch.cuda.amp.autocast()
        else:
            torch_amp_context = contextlib.suppress()

        # Skip DDP gradient synchronization unless this iteration actually
        # performs an optimizer step (gradient accumulation may skip steps).
        do_step = self._should_do_step()
        skip_model_sync = self.distributed_model is not None and not do_step
        ctx_mgr_model = (
            self.distributed_model.no_sync()
            if skip_model_sync
            else contextlib.suppress()
        )
        skip_loss_sync = self.distributed_loss is not None and not do_step
        ctx_mgr_loss = (
            self.distributed_loss.no_sync()
            if skip_loss_sync
            else contextlib.suppress()
        )

        with ctx_mgr_model, ctx_mgr_loss:
            # Forward pass.
            with torch.enable_grad(), torch_amp_context:
                output = self.model(sample["input"])

                local_loss = self.compute_loss(output, sample)
                loss = local_loss.detach().clone()
                self.losses.append(loss.data.cpu().item())

                self.update_meters(output, sample)

            # Backward pass + optimizer step.
            self.run_optimizer(local_loss)

        self.num_updates += self.get_global_batchsize()

        # Expose batch data on the task so hooks get a chance to access it.
        self.last_batch = LastBatchInfo(
            loss=loss,
            output=output,
            target=target,
            sample=sample,
            step_data={"sample_fetch_time": fetch_timer.elapsed_time},
        )
Exemple #20
0
    def forward(ctx, X, Y):
        """Forward pass of a custom autograd Function.

        Delegates to ``forward_quaternion_X_times_Y_inv`` — presumably the
        quaternion product X * Y^{-1}, judging by the name (confirm against
        the helper's definition) — and saves both inputs for backward.
        """

        # autograd.Function.forward normally runs under no_grad; enable_grad
        # is forced here, presumably so the helper can build a graph
        # internally — TODO confirm against the helper.
        with tr.enable_grad():
            c = forward_quaternion_X_times_Y_inv(X, Y)
        # Keep the inputs for the backward pass.
        ctx.save_for_backward(X, Y)
        return c
def trades_loss(model,
                loss_fn,
                x_natural,
                y,
                norm,
                optimizer,
                step_size=0.003,
                epsilon=0.031,
                perturb_steps=10,
                beta=1.0,
                version=None,
                device="gpu"):
    """TRADES loss: natural classification loss + beta * robust KL loss.

    An adversarial example ``x_adv`` is generated by maximizing the KL
    divergence between the model's predictions on perturbed and natural
    inputs; the returned loss combines the natural loss with the robust
    KL term (Zhang et al., ICML 2019).

    Args:
        model: classifier returning logits for a batch.
        loss_fn: natural-loss criterion, called as ``loss_fn(outputs, y)``.
        x_natural: clean input batch; values assumed in [0, 1] (clamped).
        y: ground-truth labels.
        norm: attack norm; ``np.inf`` (sign-based PGD) or ``2`` (SGD on the
            perturbation) are supported.
        optimizer: model optimizer; its grads are zeroed before the robust
            loss is built.
        step_size: attack step size (inf-norm branch).
        epsilon: attack radius.
        perturb_steps: number of attack iterations.
        beta: robust-loss weight.
        version: if it contains "plus", the per-example KL is normalized by
            the perturbation norm; if it contains "sum", the robust term is
            scaled by the batch size.
        device: device for the random-init noise. NOTE: the default "gpu"
            is not a valid torch device string; callers must pass e.g.
            "cuda" or "cpu".

    Returns:
        (outputs, loss): logits on ``x_natural`` and the total TRADES loss.

    Raises:
        ValueError: if ``norm`` is neither ``np.inf`` nor ``2``.
    """
    # "plus" variants need the unreduced KL so it can be normalized per example.
    #criterion_kl = nn.KLDivLoss(size_average=False)
    if version is not None and "plus" in version:
        criterion_kl = nn.KLDivLoss(reduction='none')
    else:
        criterion_kl = nn.KLDivLoss(reduction='sum')
    model.eval()  # freeze BN/dropout statistics while crafting the attack
    batch_size = len(x_natural)
    # generate adversarial example, starting from a small random perturbation
    if norm == np.inf:
        x_adv = x_natural.detach() + 0.001 * torch.randn(
            x_natural.shape).to(device).detach()
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
                if version is not None and "plus" in version:
                    # Normalize each example's KL by its perturbation norm.
                    loss_kl = torch.sum(loss_kl, dim=1) \
                            / torch.norm(torch.flatten(x_adv - x_natural, start_dim=1), p=norm, dim=1)
                    loss_kl = torch.sum(loss_kl)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            # Ascend the KL, then project into the epsilon-ball and [0, 1].
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                              x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif norm == 2:
        delta = 0.001 * torch.randn(x_natural.shape).to(device).detach()
        delta = Variable(delta.data, requires_grad=True)

        # Setup optimizer over the perturbation itself
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

        for _ in range(perturb_steps):
            adv = x_natural + delta

            # optimize: minimizing -KL == maximizing KL
            optimizer_delta.zero_grad()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
                if version is not None and "plus" in version:
                    # BUGFIX: this branch previously read `loss_kl` before
                    # assignment and `x_adv` (only created after the loop),
                    # raising NameError, and it also dropped the (-1)
                    # factor so SGD would *minimize* the KL. Normalize the
                    # per-example KL by the current perturbation norm and
                    # negate it for gradient ascent.
                    loss_kl = torch.sum(loss_kl, dim=1) \
                            / torch.norm(torch.flatten(delta, start_dim=1), p=norm, dim=1)
                    loss = (-1) * torch.sum(loss_kl)
                else:
                    loss = (-1) * loss_kl
            loss.backward()
            # renorming gradient
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # avoid nan or inf if gradient is 0
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(
                    delta.grad[grad_norms == 0])
            optimizer_delta.step()

            # projection onto the valid image range and the epsilon L2-ball
            delta.data.add_(x_natural)
            delta.data.clamp_(0, 1).sub_(x_natural)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = Variable(x_natural + delta, requires_grad=False)
    else:
        # BUGFIX: the original fell through to torch.clamp(x_adv, ...) with
        # `x_adv` never defined (NameError); fail fast with a clear error.
        raise ValueError("Unsupported norm for TRADES: {}".format(norm))
    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
    #x_adv = Variable(x_adv, requires_grad=False)
    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    #outputs = F.softmax(model(x_natural), dim=1)
    outputs = model(x_natural)
    loss_natural = loss_fn(outputs, y)
    if version is not None and "plus" in version:
        loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                               F.softmax(model(x_natural), dim=1))
        loss_kl = torch.sum(loss_kl, dim=1) \
                / torch.norm(torch.flatten(x_adv - x_natural, start_dim=1), p=norm, dim=1)
        loss_robust = (1.0 / batch_size) * torch.sum(loss_kl)
    else:
        loss_robust = (1.0 / batch_size) * criterion_kl(
            F.log_softmax(model(x_adv), dim=1),
            F.softmax(model(x_natural), dim=1))
    if version is not None and "sum" in version:
        loss = loss_natural + beta * batch_size * loss_robust
    else:
        loss = loss_natural + beta * loss_robust
    return outputs, loss
    def step(self, closure=None):
        """Performs a single optimization step for sparse gradients.

        SGD with optional momentum/Nesterov where the momentum buffer is
        kept dense but only updated (and applied) on the coordinates where
        the sparse gradient is non-zero.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by ``closure`` if one was given, else ``None``.
        """
        loss = None
        if closure is not None:
            # Run the closure with autograd enabled even under no_grad().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # Merge duplicate indices so each coordinate appears once.
                grad = grad.coalesce()
                grad_inds = grad._indices()
                grad_values = grad._values()
                size = grad.size()

                def make_sparse(values):
                    # Rebuild a sparse tensor with the gradient's sparsity
                    # pattern but the given values.
                    constructor = grad.new
                    if grad_inds.dim() == 0 or values.dim() == 0:
                        return constructor().resize_as_(grad)
                    return constructor(grad_inds,
                                       values.reshape(grad_values.shape), size)

                if momentum != 0:
                    param_state = self.state[p]

                    if "momentum_buffer" not in param_state:
                        # Dense buffer initialized from the first gradient.
                        buf = param_state["momentum_buffer"] = torch.clone(
                            grad).detach().to_dense()
                    else:
                        buf = param_state["momentum_buffer"]
                        # Only update momentum_buffer where sparse gradient is non-zero
                        # BUGFIX: `buf[grad_inds].mul_(momentum)` mutated a
                        # *copy* produced by advanced indexing, so the decay
                        # was silently dropped; assign back via __setitem__.
                        buf[grad_inds] = buf[grad_inds] * momentum
                        buf.add_(grad, alpha=(1 - dampening))

                    # NOTE(review): indexing with the raw 2-D `grad_inds`
                    # tensor only addresses dim 0, so this path appears to
                    # assume 1-D parameters — confirm for multi-dim params.
                    mom_values = buf[grad_inds].squeeze()

                    if nesterov:
                        # Nesterov look-ahead: grad + momentum * buffer.
                        mom_values = grad_values.add(mom_values,
                                                     alpha=momentum)

                    p.data.add_(make_sparse(mom_values), alpha=-lr)
                else:
                    p.add_(grad, alpha=-lr)

                # Weight decay applied only where the gradient is non-zero.
                if weight_decay != 0:
                    p.add_(p.sparse_mask(grad), alpha=-lr * weight_decay)

        return loss
Exemple #23
0
    def _train_step(self) -> None:
        """Run a training step over the training data.

        Performs ``iter_per_step`` optimizer iterations; each iteration
        accumulates gradients over ``batches_per_iter`` batches before
        clipping, stepping the optimizer, and logging loss/metrics.
        """
        self.model.train()
        # Pair each training metric with a fresh (empty) state dict.
        metrics_with_states: List[Tuple] = [
            (metric, {}) for metric in self.training_metrics
        ]
        self._last_train_log_step = 0

        log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
        log_prefix += 'Training'

        with torch.enable_grad():
            # NOTE(review): `i` and `global_step` are read after this loop,
            # so iter_per_step is presumably always >= 1 — confirm.
            for i in range(self.iter_per_step):

                # Zero the gradients and clear the accumulated loss
                self.optimizer.zero_grad()
                accumulated_loss = 0.0
                for _ in range(self.batches_per_iter):

                    # Get next batch; restart the iterator when exhausted.
                    try:
                        batch = next(self._train_iterator)
                    except StopIteration:
                        self._create_train_iterator()
                        batch = next(self._train_iterator)

                    # NOTE(review): with device_count > 1 the batch is
                    # presumably moved inside _compute_batch — confirm.
                    if self.device_count == 1:
                        batch = self._batch_to_device(batch)

                    _, _, loss = self._compute_batch(batch,
                                                     metrics_with_states)
                    # Average the loss over the accumulation window.
                    accumulated_loss += loss.item() / self.batches_per_iter

                    # Gradients accumulate across the backward calls.
                    loss.backward()

                # Log loss
                global_step = (self.iter_per_step * self._step) + i

                # Clip gradients if necessary
                if self.max_grad_norm:
                    clip_grad_norm_(self.model.parameters(),
                                    self.max_grad_norm)
                if self.max_grad_abs_val:
                    clip_grad_value_(self.model.parameters(),
                                     self.max_grad_abs_val)

                log(f'{log_prefix}/Loss', accumulated_loss, global_step)
                # With multiple devices the model is wrapped (.module).
                if self.device_count > 1:
                    log(f'{log_prefix}/Gradient_Norm',
                        self.model.module.gradient_norm, global_step)
                    log(f'{log_prefix}/Parameter_Norm',
                        self.model.module.parameter_norm, global_step)
                else:
                    log(f'{log_prefix}/Gradient_Norm',
                        self.model.gradient_norm, global_step)
                    log(f'{log_prefix}/Parameter_Norm',
                        self.model.parameter_norm, global_step)

                # Optimize
                self.optimizer.step()

                # Update iter scheduler
                if self.iter_scheduler is not None:
                    lr = self.optimizer.param_groups[0]['lr']  # type: ignore
                    log(f'{log_prefix}/LR', lr, global_step)
                    self.iter_scheduler.step()  # type: ignore

                # Zero the gradients when exiting a train step
                self.optimizer.zero_grad()
                # logging train metrics
                # NOTE(review): comparing the interval against the last log
                # step *index* looks off — one would expect something like
                # `i - self._last_train_log_step >= interval`; confirm
                # intended behavior before changing.
                if self.extra_training_metrics_log_interval > self._last_train_log_step:
                    self._log_metrics(log_prefix, metrics_with_states,
                                      global_step)
                    self._last_train_log_step = i

            if self._last_train_log_step != i:
                # log again at end of step, if not logged at the end of
                # step before
                self._log_metrics(log_prefix, metrics_with_states, global_step)
Exemple #24
0
    def step(self, closure=None):
        """Perform a single NAdam optimization step.

        Collects the per-parameter state for every parameter that received
        a gradient, then delegates the actual update to ``nadam``.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']

            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            mu_products = []
            state_steps = []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError(
                        'NAdam does not support sparse gradients')

                state = self.state[p]
                # Lazy state initialization on first use.
                if not state:
                    state['step'] = torch.tensor(0.)
                    state['mu_product'] = torch.tensor(1.)
                    # Exponential moving averages of the gradient and of
                    # its elementwise square.
                    state['exp_avg'] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)
                    state['exp_avg_sq'] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)

                params_with_grad.append(p)
                grads.append(p.grad)
                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                mu_products.append(state['mu_product'])
                state_steps.append(state['step'])

            # Functional NAdam update over the gathered tensors.
            nadam(params_with_grad,
                  grads,
                  exp_avgs,
                  exp_avg_sqs,
                  mu_products,
                  state_steps,
                  beta1=beta1,
                  beta2=beta2,
                  lr=group['lr'],
                  weight_decay=group['weight_decay'],
                  momentum_decay=group['momentum_decay'],
                  eps=group['eps'],
                  foreach=group['foreach'])

        return loss
Exemple #25
0
def evaluate(experiment_directory, checkpoint_path):
    """Fit an affine transform (3x3 linear map + learnable per-axis scales
    + bias) that maps a test point cloud onto the zero level-set of a
    pre-trained SDF decoder, then save the transform and a loss plot.

    Args:
        experiment_directory: experiment folder containing ``specs`` (read
            via ``ws.load_experiment_specifications``).
        checkpoint_path: path to the decoder checkpoint to load.
    """

    chamfer_results = []

    specs = ws.load_experiment_specifications(experiment_directory)
    logging.info("Experiment description: \n" + specs["Description"])

    data_source = specs["DataSource"]
    test_split_file = specs["TestSplit"]

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = False
    num_data_loader_threads =1
    # scene_per_subbatch =1
    batch_split = 1
    scene_per_subbatch = scene_per_batch // batch_split


    checkpoints = list(
        range(
            specs["SnapshotFrequency"],
            specs["NumEpochs"] + 1,
            specs["SnapshotFrequency"],
        )
    )

    # NOTE(review): defined but never registered via signal.signal —
    # appears to be dead code.
    def signal_handler(sig, frame):
        logging.info("Stopping early...")
        sys.exit(0)
    
    with open(test_split_file,"r") as f:
        test_split = json.load(f)

    sdf_dataset = dt_vtk.SDFVTKSamples(
        data_source, test_split, num_samp_per_scene
    )

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_subbatch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )


    decoder_eval = decoder.Decoder(0, **specs["NetworkSpecs"]).cuda()

    # for epoch in range(start_epoch, num_epochs + 1):

    #     start = time.time()

    #     logging.info("epoch {}...".format(epoch))
    # pdb.set_trace()
    checkpoint = torch.load(checkpoint_path)
    decoder_eval.load_state_dict(checkpoint['model_state_dict'])
    decoder_eval = decoder_eval.float()
    # decoder_eval.eval()
    # Freeze the decoder: only the transform parameters below are optimized.
    for param in decoder_eval.parameters():
        param.requires_grad = False
    loss_l1 = torch.nn.L1Loss()
    loss_l2 = torch.nn.MSELoss()
    loss_log =[]
    # theta_x = torch.randn(1, requires_grad=True, dtype=torch.float)*3.1415
    # theta_y = torch.randn(1, requires_grad=True, dtype=torch.float)*3.1415
    # theta_z = torch.randn(1, requires_grad=True, dtype=torch.float)*3.1415
    # theta_x = theta_x.float()
    # theta_y = theta_y.float()
    # theta_z = theta_z.float()
    # theta_x.retain_grad()
    # theta_y.retain_grad()
    # theta_z.retain_grad()
    # Learnable per-axis scale factors (assembled into a diagonal matrix below).
    scale_one = torch.randn(1, requires_grad=True, dtype=torch.float)
    scale_two = torch.randn(1, requires_grad=True, dtype=torch.float)
    scale_three = torch.randn(1, requires_grad=True, dtype=torch.float)
    scale_one.retain_grad()
    scale_two.retain_grad()
    scale_three.retain_grad()
    # transform_matrix = torch.zeros(3,3).float().cuda()
    # transform_matrix.requires_grad_(True)
    # transform_matrix.retain_grad()
    # Learnable unconstrained 3x3 linear map and 3-vector bias.
    transform_inpt = torch.randn(3,3).float().cuda()
    transform_inpt.requires_grad_(True)
    transform_inpt.retain_grad()
    bias = torch.zeros(3).float().cuda()
    bias.requires_grad_(True)
    bias.retain_grad()
    # pdb.set_trace()
    test_model = np.array(pd.read_csv("../chairs_segdata/points/1a8bbf2994788e2743e99e0cae970928.pts", header=None,sep=" ").values)
    num_epochs = 500
    learning_rate = 1e-3
    test_pts = torch.from_numpy(test_model).float()
    # NOTE(review): typo — 'requies_grad' only sets an unused attribute;
    # `requires_grad` was intended (False is already the default here, so
    # behavior happens to be correct anyway).
    test_pts.requies_grad = False
    bt_size = 32
    num_batches = int(test_model.shape[0]//bt_size)
    sub = torch.Tensor([1]).cuda()
    reg = 1
    # rot_x = torch.from_numpy(rotate_mat_x()).double().cuda()
    # rot_x.requires_grad_(True)
    # rot_y = torch.from_numpy(rotate_mat_y()).double().cuda()
    # rot_y.requires_grad_(True)
    # rot_z = torch.from_numpy(rotate_mat_z()).double().cuda()
    # rot_z.requires_grad_(True)
    # Optimize the transform so that decoder(transform(points)) ~= 0,
    # i.e. the transformed points lie on the decoder's zero level-set.
    with torch.enable_grad():
        for j in range(num_epochs):
            # pdb.set_trace()
            # Process the input datag
            # sdf_data.requires_grad = False

            # sdf_data = (sdf_data.cuda()).reshape(
            #     num_samp_per_scene * scene_per_subbatch, 4
            # )
            
            # xyz = sdf_data[:, 0:3]
            # transform_matrix_update = torch.add(transform_matrix,bias)
            batch_loss=0
            for i in range(num_batches):
                test_torch = test_pts[i*bt_size:(i+1)*bt_size,:]
                # pdb.set_trace()
                # cosval_x = torch.cos(theta_x)
                # sinval_x = torch.sin(theta_x)
                # cosval_x.requires_grad_(True)
                # sinval_x.requires_grad_(True)
                # cosval_x.retain_grad()
                # sinval_x.retain_grad()
                # rot_x = torch.stack([torch.Tensor([1, 0, 0]),
                #             torch.cat([torch.Tensor([0]), cosval_x, -sinval_x]),
                #             torch.cat([torch.Tensor([0]), sinval_x, cosval_x])], dim=1).float().cuda()
                # rot_x.requires_grad_(True)
                # rot_x.retain_grad()
                # cosval_y = torch.cos(theta_y)
                # sinval_y = torch.sin(theta_y)
                # cosval_y.requires_grad_(True)
                # sinval_y.requires_grad_(True)
                # cosval_y.retain_grad()
                # sinval_y.retain_grad()
                # rot_y = torch.stack([torch.cat([cosval_y, torch.Tensor([0]), sinval_y]),
                #                     torch.Tensor([0, 1, 0]),
                #                     torch.cat([-sinval_y, torch.Tensor([0]), cosval_y])],dim=1).float().cuda()
                # rot_y.requires_grad_(True)
                # rot_y.retain_grad()
                # cosval_z = torch.cos(theta_z)
                # sinval_z = torch.sin(theta_z)
                # cosval_z.requires_grad_(True)
                # sinval_z.requires_grad_(True)
                # cosval_z.retain_grad()
                # sinval_z.retain_grad()
                # rot_z = torch.stack([torch.cat([cosval_z, -sinval_z, torch.Tensor([0])]),
                #                     torch.cat([sinval_z, cosval_z, torch.Tensor([0])]),
                #                     torch.Tensor([0, 0, 1])], dim=1).float().cuda()
                # rot_z.requires_grad_(True)
                # rot_z.retain_grad()
                # Diagonal scale matrix built from the three learnable scales.
                scale_matrix = torch.cat([torch.cat([scale_one,torch.Tensor([0]),torch.Tensor([0])]),
                                          torch.cat([torch.Tensor([0]),scale_two,torch.Tensor([0])]),
                                          torch.cat([torch.Tensor([0]),torch.Tensor([0]),scale_three])]).view(3,3).float().cuda()
                # pdb.set_trace()
                scale_matrix.retain_grad()
                scale_matrix.requires_grad_(True)
                # transform_matrix = torch.matmul(torch.matmul(torch.matmul(rot_z,rot_y),rot_x),scale_matrix)
                transform_matrix = torch.matmul(transform_inpt, scale_matrix)
                transform_matrix.requires_grad_(True)
                transform_matrix.retain_grad()
                xyz = test_torch.cuda()
                xyz_transform = torch.matmul(xyz, transform_matrix)
                xyz_transform.requires_grad_(True)
                xyz_transform.retain_grad()
                transform_bias = torch.add(xyz_transform, bias).float()
                transform_bias.retain_grad()
                # diag_sum = torch.abs(torch.sum(torch.diag(transform_matrix)))
                # sdf_gt = sdf_data[:, 3].unsqueeze(1)
                pred_sdf = decoder_eval(transform_bias)
                # pred_sdf = decoder_eval(xyz_transform)
                # loss = loss_l1(pred_sdf, sdf_gt)
                # Target SDF is zero: points should lie on the surface.
                target = torch.zeros(pred_sdf.shape[0],pred_sdf.shape[1]).float().cuda()
                # batch_loss += loss.item()
                # pdb.set_trace()
                # Regularizer: keep the learned scales close to 1.
                diag_sum = torch.norm(torch.sub(torch.diag(scale_matrix),sub),2)
                diag_sum.retain_grad()
                diag_sum.requires_grad_(True)
                # diag_sum = torch.sum(torch.diag(transform_matrix)).cpu()
                loss1 = loss_l1(pred_sdf,target)
                loss2 = reg *diag_sum
                # loss2 = torch.abs(torch.sub(diag_sum,1))
                loss = torch.add(loss1,loss2)
                loss.backward(retain_graph=True)
                batch_loss+= loss.item()
                print('Batch Loss {:6.4f}'.format(loss.item()))
                # Manual SGD update on the transform parameters, then zero
                # every retained gradient for the next batch.
                with torch.no_grad():
                    # theta_z.data.sub_(theta_z.grad.data*learning_rate)
                    # theta_y.data.sub_(theta_y.data*learning_rate)
                    # theta_x.data.sub_(theta_x.grad.data*learning_rate)
                    bias.data.sub_(bias.grad.data*learning_rate)
                    scale_one.data.sub_(scale_one.grad.data*learning_rate)
                    scale_two.data.sub_(scale_two.grad.data*learning_rate)
                    scale_three.data.sub_(scale_three.grad.data*learning_rate)
                    transform_inpt.data.sub_(transform_inpt.grad.data*learning_rate)
                    # theta_z.grad.data.zero_()
                    # theta_y.grad.data.zero_()
                    # theta_x.grad.data.zero_()
                    bias.grad.data.zero_()
                    scale_one.grad.data.zero_()
                    scale_three.grad.data.zero_()
                    scale_two.grad.data.zero_()
                    scale_matrix.grad.data.zero_()
                    transform_bias.grad.data.zero_()
                    xyz_transform.grad.data.zero_()
                    transform_matrix.grad.data.zero_()
                    transform_inpt.grad.data.zero_()
                    diag_sum.grad.data.zero_()
                    # rot_z.grad.data.zero_()
                    # rot_x.grad.data.zero_()
                    # rot_y.grad.data.zero_()
            # pdb.set_trace()
            actual_loss = (batch_loss*bt_size)/(test_model.shape[0])
            # print("Loss after {} epoch is {:6.4f}".format(j,batch_loss))
            print("Loss after {} epoch is {:6.4f}".format(j,actual_loss))
            loss_log.append(actual_loss)
    # NOTE(review): leftover debug breakpoint — execution halts here until
    # a debugger command is issued.
    pdb.set_trace()
    fig,ax = plt.subplots()
    ax.plot(np.arange(num_epochs),loss_log)
    ax.set(xlabel='iterations',ylabel='transformationloss')
    plt.savefig('Transformation_loss_new.png')
    torch.save(transform_matrix,'transform_matrix_new.pt')
    torch.save(bias,'bias_new.pt')
    # Apply the learned transform to a second point cloud and save it.
    test_pts = torch.from_numpy(pd.read_csv('test_model.pts',header=None, sep=' ').values).cuda()
    transform_pts = torch.matmul(test_pts, transform_matrix.double())
    transform_pts = torch.add(transform_pts, bias.double()).cpu().detach().numpy()
    np.savetxt('transform_points_new.pts',transform_pts)
    plot_heatmap(experiment_directory, checkpoint_path)
Exemple #26
0
    def _propagate_compression(self):
        """Propagate the compression to other layers for max sparsity.

        Here we apply a simple heuristic that is true for any kind of pruning:
        * If the gradient of some parameters is consistently 0 across multiple
          batches of data, we can safely prune that parameter as well.
        * e.g.: a bias parameter of a channel that gets never used.

        Depending on FilterNet or WeightNet, we might apply additional
        heuristics to propagate the compression.
        """
        def _zero_grad(net):
            """Set gradients of all parameters back to None."""
            for param in net.parameters():
                param.grad = None

        # zero-out gradients so the accumulation below starts clean
        _zero_grad(self.compressed_net)

        # get the device from the first compressible layer's weight
        device = self.compressed_net.compressible_layers[0].weight.device

        # make sure we are in eval mode (avoid updates to BN, etc...)
        is_training_mode = self.compressed_net.training
        self.compressed_net.eval()

        # do a couple of forward+backward passes to accumulate gradients
        at_least_one_batch = False
        with torch.enable_grad():
            for images, targets in self._loader_s:
                # skip degenerate single-sample batches
                if len(images) < 2:
                    continue
                at_least_one_batch = True
                images = tensor.to(images, device, non_blocking=True)
                targets = tensor.to(targets, device, non_blocking=True)
                outs = self.compressed_net(images)
                loss = self._loss_handle(outs, targets)
                loss.backward()
        assert at_least_one_batch, "No batch with more than one data point!"

        # post-process gradients to set respective weights to zero
        some_grad_none = False
        with torch.no_grad():
            for param in self._parameters_for_grad_prune():
                grad = param.grad
                if grad is None:
                    some_grad_none = True
                    continue

                # mask anything at machine precision or below.
                prune_mask = self._get_prune_mask_from_grad(grad)
                param.masked_fill_(prune_mask, 0.0)

        # issue warning in case some gradients were None
        if some_grad_none:
            # BUGFIX: corrected grammar in the user-facing warning message
            # ("did not received" -> "did not receive").
            warnings.warn("Some parameters did not receive gradients"
                          " while propagating compression!")

        # zero-out gradients one more time at the end
        _zero_grad(self.compressed_net)

        # revert back to training mode if it was in training mode before
        self.compressed_net.train(is_training_mode)
Exemple #27
0
def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=0.003,
                epsilon=0.031,
                perturb_steps=10,
                beta=1.0,
                distance='l_inf'):
    """Craft an adversarial example around ``x_natural`` and return the
    cross-entropy loss of the model on it (TRADES-style training step).

    Args:
        model: network; expected to return a ``(features, logits)`` pair.
        x_natural: clean input batch in [0, 1].
        y: integer class labels for ``x_natural``.
        optimizer: optimizer whose gradients are zeroed before the final
            forward pass.
        step_size: per-iteration attack step size.
        epsilon: attack budget (l_inf ball radius / l_2 ball radius).
        perturb_steps: number of attack iterations.
        beta: unused here; kept for interface compatibility (the full TRADES
            KL term is not computed — the returned loss is plain
            cross-entropy on the adversarial batch).
        distance: ``'l_inf'`` (default) or ``'l_2'`` attack geometry; any
            other value skips the attack and only clamps the random start.

    Returns:
        Scalar cross-entropy loss on the adversarial examples.
    """
    # define KL-loss (`size_average` is deprecated; reduction='sum' is the
    # exact equivalent). Only used by the l_2 branch below.
    criterion_kl = nn.KLDivLoss(reduction='sum')
    model.eval()
    batch_size = len(x_natural)
    # Random start around the clean input. randn_like keeps the noise on the
    # same device/dtype as the input, generalizing the original hard-coded
    # `.cuda()` call (backward-compatible: CUDA inputs still get CUDA noise).
    x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)
    if distance == 'l_inf':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                features, logits = model(x_adv)
                # NOTE: despite the KL criterion above, this branch maximizes
                # plain cross-entropy (PGD-style), as in the original code.
                loss_kl = F.cross_entropy(logits, y)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            # Gradient-sign ascent step, projected back into the eps-ball
            # around x_natural, then clamped to the valid pixel range.
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon),
                              x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif distance == 'l_2':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                       F.softmax(model(x_natural), dim=1))
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            # NOTE(review): the in-place per-sample writes below mutate a
            # tensor that still has requires_grad set; recent PyTorch versions
            # may reject in-place ops on leaf tensors requiring grad — verify
            # this branch on the targeted torch version.
            for idx_batch in range(batch_size):
                grad_idx = grad[idx_batch]
                grad_idx_norm = l2_norm(grad_idx)
                grad_idx /= (grad_idx_norm + 1e-8)
                x_adv[idx_batch] = x_adv[idx_batch].detach(
                ) + step_size * grad_idx
                eta_x_adv = x_adv[idx_batch] - x_natural[idx_batch]
                norm_eta = l2_norm(eta_x_adv)
                if norm_eta > epsilon:
                    eta_x_adv = eta_x_adv * epsilon / l2_norm(eta_x_adv)
                x_adv[idx_batch] = x_natural[idx_batch] + eta_x_adv
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    else:
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    # Freeze the adversarial example before the training forward pass
    # (.detach() replaces the deprecated Variable(..., requires_grad=False)).
    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    # zero gradient
    optimizer.zero_grad()
    # calculate robust loss
    _, logits = model(x_adv)
    loss = F.cross_entropy(logits, y)
    return loss
Exemple #28
0
    def step(self, closure=None):
        """Perform one optimization step (Adam with gradient centralization).

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # These are constant for every parameter in the group.
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdamW does not support sparse gradients')

                state = self.state[p]
                if not state:
                    # Lazily set up per-parameter buffers on first update.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsgrad:
                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                step_t = state['step']
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']

                # Optionally centralize the raw gradient before the moment
                # updates ("local" GC position).
                if self.gc_loc:
                    grad = centralized_gradient(grad,
                                                use_gc=self.use_gc,
                                                gc_conv_only=self.gc_conv_only)

                # In-place EMA updates of the first and second moments.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                if amsgrad:
                    # Track the running max of the second moment and use it
                    # for normalization (AMSGrad variant).
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    second_moment = max_exp_avg_sq
                else:
                    second_moment = exp_avg_sq

                denom = (second_moment.sqrt() /
                         math.sqrt(1 - beta2**step_t)).add_(group['eps'])

                step_size = group['lr'] / (1 - beta1**step_t)

                # Decoupled weight decay folded into the update direction,
                # then (optionally) centralized at the "global" GC position.
                update = (exp_avg / denom).add(p.data,
                                               alpha=group['weight_decay'])
                if self.gc_loc == False:
                    update = centralized_gradient(
                        update,
                        use_gc=self.use_gc,
                        gc_conv_only=self.gc_conv_only)

                p.add_(update, alpha=-step_size)

        return loss
Exemple #29
0
    def forward(ctx, x, y, xnew, out=None):
        """
        Linear 1D interpolation on the GPU for Pytorch.
        This function returns interpolated values of a set of 1-D functions at
        the desired query points `xnew`.
        This function is working similarly to Matlab™ or scipy functions with
        the `linear` interpolation mode on, except that it parallelises over
        any number of desired interpolation problems.
        The code will run on GPU if all the tensors provided are on a cuda
        device.
        Parameters
        ----------
        x : (N, ) or (D, N) Pytorch Tensor
            A 1-D or 2-D tensor of real values.
        y : (N,) or (D, N) Pytorch Tensor
            A 1-D or 2-D tensor of real values. The length of `y` along its
            last dimension must be the same as that of `x`
        xnew : (P,) or (D, P) Pytorch Tensor
            A 1-D or 2-D tensor of real values. `xnew` can only be 1-D if
            _both_ `x` and `y` are 1-D. Otherwise, its length along the first
            dimension must be the same as that of whichever `x` and `y` is 2-D.
        out : Pytorch Tensor, same shape as `xnew`
            Tensor for the output. If None: allocated automatically.
        """
        # making the vectors at least 2D
        is_flat = {}       # name -> True when the 2-D view has a single row
        require_grad = {}  # name -> whether the *original* input tracks grad
        v = {}             # name -> 2-D view of each input
        device = []
        # eps below guards the slope division against duplicated x values
        eps = torch.finfo(y.dtype).eps
        for name, vec in {"x": x, "y": y, "xnew": xnew}.items():
            assert len(vec.shape) <= 2, "interp1d: all inputs must be " "at most 2-D."
            if len(vec.shape) == 1:
                v[name] = vec[None, :]
            else:
                v[name] = vec
            is_flat[name] = v[name].shape[0] == 1
            require_grad[name] = vec.requires_grad
            device = list(set(device + [str(vec.device)]))
        assert len(device) == 1, "All parameters must be on the same device."
        device = device[0]

        # Checking for the dimensions
        assert v["x"].shape[1] == v["y"].shape[1] and (
            v["x"].shape[0] == v["y"].shape[0] or v["x"].shape[0] == 1 or v["y"].shape[0] == 1
        ), (
            "x and y must have the same number of columns, and either "
            "the same number of row or one of them having only one "
            "row."
        )

        reshaped_xnew = False
        if (v["x"].shape[0] == 1) and (v["y"].shape[0] == 1) and (v["xnew"].shape[0] > 1):
            # if there is only one row for both x and y, there is no need to
            # loop over the rows of xnew because they will all have to face the
            # same interpolation problem. We should just stack them together to
            # call interp1d and put them back in place afterwards.
            original_xnew_shape = v["xnew"].shape
            v["xnew"] = v["xnew"].contiguous().view(1, -1)
            reshaped_xnew = True

        # identify the dimensions of output and check if the one provided is ok
        D = max(v["x"].shape[0], v["xnew"].shape[0])
        shape_ynew = (D, v["xnew"].shape[-1])
        if out is not None:
            if out.numel() != shape_ynew[0] * shape_ynew[1]:
                # The output provided is of incorrect shape.
                # Going for a new one
                out = None
            else:
                ynew = out.reshape(shape_ynew)
        if out is None:
            ynew = torch.zeros(*shape_ynew, device=device)
        # NOTE(review): `ynew` is rebound to a freshly computed tensor inside
        # the enable_grad block below, so a user-supplied `out` only serves as
        # scratch space for the index buffer `ind` — it is never filled with
        # the interpolation results. Confirm whether writing into `out` was
        # intended.

        # moving everything to the desired device in case it was not there
        # already (not handling the case things do not fit entirely, user will
        # do it if required.)
        for name in v:
            v[name] = v[name].to(device)

        # calling searchsorted on the x values.
        # Integer buffer with ynew's shape, reused as searchsorted's output.
        ind = ynew.long()

        # expanding xnew to match the number of rows of x in case only one xnew is
        # provided
        if v["xnew"].shape[0] == 1:
            v["xnew"] = v["xnew"].expand(v["x"].shape[0], -1)

        torch.searchsorted(v["x"].contiguous(), v["xnew"].contiguous(), out=ind)

        # the `-1` is because searchsorted looks for the index where the values
        # must be inserted to preserve order. And we want the index of the
        # preceeding value.
        ind -= 1
        # we clamp the index, because the number of intervals is x.shape-1,
        # and the left neighbour should hence be at most number of intervals
        # -1, i.e. number of columns in x -2
        ind = torch.clamp(ind, 0, v["x"].shape[1] - 1 - 1)

        # helper function to select stuff according to the found indices.
        def sel(name):
            if is_flat[name]:
                return v[name].contiguous().view(-1)[ind]
            return torch.gather(v[name], 1, ind)

        # activating gradient storing for everything now
        enable_grad = False
        saved_inputs = []
        for name in ["x", "y", "xnew"]:
            if require_grad[name]:
                enable_grad = True
                saved_inputs += [v[name]]
            else:
                saved_inputs += [
                    None,
                ]
        # assuming x are sorted in the dimension 1, computing the slopes for
        # the segments
        is_flat["slopes"] = is_flat["x"]
        # now we have found the indices of the neighbors, we start building the
        # output. Hence, we start also activating gradient tracking
        with torch.enable_grad() if enable_grad else contextlib.suppress():
            # eps keeps the division finite when consecutive x values coincide
            v["slopes"] = (v["y"][:, 1:] - v["y"][:, :-1]) / (eps + (v["x"][:, 1:] - v["x"][:, :-1]))

            # now build the linear interpolation
            ynew = sel("y") + sel("slopes") * (v["xnew"] - sel("x"))

            if reshaped_xnew:
                ynew = ynew.view(original_xnew_shape)

        ctx.save_for_backward(ynew, *saved_inputs)
        return ynew
Exemple #30
0
    def train(self, training_batch: rlt.PolicyNetworkInput) -> None:
        """
        IMPORTANT: the input action here is assumed to match the
        range of the output of the actor.

        One SAC-style training step on `training_batch`:
          1. optionally tune the entropy temperature alpha;
          2. update the Q-network(s) towards r + gamma * V'(next_s);
          3. update the actor against min(Q1, Q2) with an entropy bonus,
             plus an optional KLD regularizer on the action statistics;
          4. optionally update the state-value network;
          5. soft-update the corresponding target network(s);
          6. periodically log to TensorBoard and report losses.
        """

        assert isinstance(training_batch, rlt.PolicyNetworkInput)

        # Minibatch counter; drives the TensorBoard logging cadence below.
        self.minibatch += 1

        state = training_batch.state
        action = training_batch.action
        reward = training_batch.reward
        # Constant discount gamma broadcast to reward's shape.
        discount = torch.full_like(reward, self.gamma)
        not_done_mask = training_batch.not_terminal

        # We need to zero out grad here because gradient from actor update
        # should not be used in Q-network update
        self.actor_network_optimizer.zero_grad()
        self.q1_network_optimizer.zero_grad()
        if self.q2_network is not None:
            self.q2_network_optimizer.zero_grad()
        if self.value_network is not None:
            self.value_network_optimizer.zero_grad()

        with torch.enable_grad():
            #
            # First, optimize Q networks; minimizing MSE between
            # Q(s, a) & r + discount * V'(next_s)
            #

            q1_value = self.q1_network(state, action)
            if self.q2_network:
                q2_value = self.q2_network(state, action)
            actor_output = self.actor_network(state)

            # Optimize Alpha
            if self.alpha_optimizer is not None:
                # Temperature loss: the gradient w.r.t. log_alpha is
                # -(log_prob + target_entropy).mean(), so alpha grows while
                # the policy entropy is below target and shrinks above it.
                alpha_loss = -(
                    (
                        self.log_alpha
                        * (actor_output.log_prob + self.target_entropy).detach()
                    ).mean()
                )
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                self.entropy_temperature = self.log_alpha.exp()

            with torch.no_grad():
                if self.value_network is not None:
                    # V'(next_s) straight from the target value network.
                    next_state_value = self.value_network_target(
                        training_batch.next_state.float_features
                    )
                else:
                    # No value network: estimate V'(next_s) from the target
                    # Q-network(s) at the actor's proposed next action.
                    next_state_actor_output = self.actor_network(
                        training_batch.next_state
                    )
                    next_state_actor_action = (
                        training_batch.next_state,
                        rlt.FeatureData(next_state_actor_output.action),
                    )
                    next_state_value = self.q1_network_target(*next_state_actor_action)

                    if self.q2_network is not None:
                        target_q2_value = self.q2_network_target(
                            *next_state_actor_action
                        )
                        # Clipped double-Q: take the element-wise minimum.
                        next_state_value = torch.min(next_state_value, target_q2_value)

                    log_prob_a = self.actor_network.get_log_prob(
                        training_batch.next_state, next_state_actor_output.action
                    )
                    # Clamp log-probs for numerical stability.
                    log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                    # Soft value: V(s') = Q(s', a') - alpha * log pi(a'|s').
                    next_state_value -= self.entropy_temperature * log_prob_a

                if self.gamma > 0.0:
                    target_q_value = (
                        reward + discount * next_state_value * not_done_mask.float()
                    )
                else:
                    # This is useful in debugging instability issues
                    target_q_value = reward

            q1_loss = F.mse_loss(q1_value, target_q_value)
            q1_loss.backward()
            self._maybe_run_optimizer(
                self.q1_network_optimizer, self.minibatches_per_step
            )
            if self.q2_network:
                q2_loss = F.mse_loss(q2_value, target_q_value)
                q2_loss.backward()
                self._maybe_run_optimizer(
                    self.q2_network_optimizer, self.minibatches_per_step
                )

            # Second, optimize the actor; minimizing KL-divergence between
            # propensity & softmax of value.  Due to reparameterization trick,
            # it ends up being log_prob(actor_action) - Q(s, actor_action)

            state_actor_action = (state, rlt.FeatureData(actor_output.action))
            q1_actor_value = self.q1_network(*state_actor_action)
            min_q_actor_value = q1_actor_value
            if self.q2_network:
                q2_actor_value = self.q2_network(*state_actor_action)
                min_q_actor_value = torch.min(q1_actor_value, q2_actor_value)

            actor_loss = (
                self.entropy_temperature * actor_output.log_prob - min_q_actor_value
            )
            # Do this in 2 steps so we can log histogram of actor loss
            # pyre-fixme[16]: `float` has no attribute `mean`.
            actor_loss_mean = actor_loss.mean()

            if self.add_kld_to_loss:
                # Batch mean/variance of the (squashed) actions, regularized
                # towards the logged action distribution via a Gaussian KLD.
                if self.apply_kld_on_mean:
                    action_batch_m = torch.mean(actor_output.squashed_mean, axis=0)
                    action_batch_v = torch.var(actor_output.squashed_mean, axis=0)
                else:
                    action_batch_m = torch.mean(actor_output.action, axis=0)
                    action_batch_v = torch.var(actor_output.action, axis=0)
                # KL( N(action_batch_m, action_batch_v) ||
                #     N(action_emb_mean, action_emb_variance) ),
                # summed over action dimensions.
                kld = (
                    0.5
                    * (
                        (action_batch_v + (action_batch_m - self.action_emb_mean) ** 2)
                        / self.action_emb_variance
                        - 1
                        + self.action_emb_variance.log()
                        - action_batch_v.log()
                    ).sum()
                )

                actor_loss_mean += self.kld_weight * kld

            actor_loss_mean.backward()
            self._maybe_run_optimizer(
                self.actor_network_optimizer, self.minibatches_per_step
            )

            #
            # Lastly, if applicable, optimize value network; minimizing MSE between
            # V(s) & E_a~pi(s) [ Q(s,a) - log(pi(a|s)) ]
            #

            if self.value_network is not None:
                state_value = self.value_network(state.float_features)

                if self.logged_action_uniform_prior:
                    log_prob_a = torch.zeros_like(min_q_actor_value)
                    target_value = min_q_actor_value
                else:
                    with torch.no_grad():
                        log_prob_a = actor_output.log_prob
                        log_prob_a = log_prob_a.clamp(-20.0, 20.0)
                        target_value = (
                            min_q_actor_value - self.entropy_temperature * log_prob_a
                        )

                value_loss = F.mse_loss(state_value, target_value.detach())
                value_loss.backward()
                self._maybe_run_optimizer(
                    self.value_network_optimizer, self.minibatches_per_step
                )

        # Use the soft update rule to update the target networks
        if self.value_network is not None:
            self._maybe_soft_update(
                self.value_network,
                self.value_network_target,
                self.tau,
                self.minibatches_per_step,
            )
        else:
            self._maybe_soft_update(
                self.q1_network,
                self.q1_network_target,
                self.tau,
                self.minibatches_per_step,
            )
            if self.q2_network is not None:
                self._maybe_soft_update(
                    self.q2_network,
                    self.q2_network_target,
                    self.tau,
                    self.minibatches_per_step,
                )

        # Logging at the end to schedule all the cuda operations first
        if (
            self.tensorboard_logging_freq != 0
            and self.minibatch % self.tensorboard_logging_freq == 0
        ):
            SummaryWriterContext.add_histogram("q1/logged_state_value", q1_value)
            if self.q2_network:
                SummaryWriterContext.add_histogram("q2/logged_state_value", q2_value)

            # pyre-fixme[16]: `SummaryWriterContext` has no attribute `add_scalar`.
            SummaryWriterContext.add_scalar(
                "entropy_temperature", self.entropy_temperature
            )
            SummaryWriterContext.add_histogram("log_prob_a", log_prob_a)
            if self.value_network:
                SummaryWriterContext.add_histogram("value_network/target", target_value)

            SummaryWriterContext.add_histogram(
                "q_network/next_state_value", next_state_value
            )
            SummaryWriterContext.add_histogram(
                "q_network/target_q_value", target_q_value
            )
            SummaryWriterContext.add_histogram(
                "actor/min_q_actor_value", min_q_actor_value
            )
            SummaryWriterContext.add_histogram(
                "actor/action_log_prob", actor_output.log_prob
            )
            SummaryWriterContext.add_histogram("actor/loss", actor_loss)
            if self.add_kld_to_loss:
                SummaryWriterContext.add_histogram("kld/mean", action_batch_m)
                SummaryWriterContext.add_histogram("kld/var", action_batch_v)
                SummaryWriterContext.add_scalar("kld/kld", kld)

        self.loss_reporter.report(
            td_loss=float(q1_loss),
            reward_loss=None,
            logged_rewards=reward,
            model_values_on_logged_actions=q1_value,
            model_propensities=actor_output.log_prob.exp(),
            model_values=min_q_actor_value,
        )
Exemple #31
0
    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Perform one AdamW optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Group-level hyperparameters, identical for every parameter.
            beta1, beta2 = group[AdamWOptimKeys.BETAS]
            lr = group[AdamWOptimKeys.LR]

            for param in group[AdamWOptimKeys.PARAMS]:
                if param.grad is None:
                    continue

                grad = param.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdamW does not support sparse gradients.')

                state: Dict = self.state[param]

                if not state:
                    # First update of this parameter: create the step counter
                    # and both exponential-moving-average buffers.
                    state[AdamWOptimKeys.STEP] = 0
                    state[AdamWOptimKeys.EMA_GRADIENT] = torch.zeros_like(
                        param.data)
                    state[AdamWOptimKeys.
                          EMA_SQUARED_GRADIENT] = torch.zeros_like(param.data)

                state[AdamWOptimKeys.STEP] += 1
                step_count = state[AdamWOptimKeys.STEP]

                ema_grad: torch.Tensor = state[AdamWOptimKeys.EMA_GRADIENT]
                ema_sq_grad: torch.Tensor = state[
                    AdamWOptimKeys.EMA_SQUARED_GRADIENT]

                # In-place EMA updates of the first and second moments.
                ema_grad.mul_(other=beta1).add_(other=grad, alpha=1 - beta1)
                ema_sq_grad.mul_(other=beta2).addcmul_(tensor1=grad,
                                                       tensor2=grad,
                                                       value=1 - beta2)
                denom = ema_sq_grad.sqrt().add_(
                    other=group[AdamWOptimKeys.EPS])

                step_size = lr
                if group[AdamWOptimKeys.CORRECT_BIAS] is True:
                    # NOTE: bias correction is deliberately skipped for
                    # BERT-style training when CORRECT_BIAS is disabled.
                    bias_fix1 = 1 - beta1**step_count
                    bias_fix2 = 1 - beta2**step_count
                    step_size = step_size * math.sqrt(bias_fix2) / bias_fix1

                param.data.addcdiv_(tensor1=ema_grad,
                                    tensor2=denom,
                                    value=-step_size)

                # Decoupled weight decay: shrink the weights directly rather
                # than adding an L2 term to the loss, which would interact
                # with the Adam moment estimates (the "fixed" AdamW scheme).
                if group[AdamWOptimKeys.WEIGHT_DECAY] > 0.0:
                    param.data.add_(other=param.data,
                                    alpha=-lr *
                                    group[AdamWOptimKeys.WEIGHT_DECAY])

        return loss
Exemple #32
0
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group['amsgrad']
            # Bug fix: read the betas once per group. Previously they were
            # only bound inside the parameter loop, so a group in which no
            # parameter had a gradient raised UnboundLocalError at the
            # F.adamw call below.
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError(
                        'AdamW does not support sparse gradients')
                grads.append(p.grad)

                state = self.state[p]

                # State initialization (lazy, on first update of this param)
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

            # Delegate the actual parameter updates to the functional AdamW.
            F.adamw(params_with_grad, grads, exp_avgs, exp_avg_sqs,
                    max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2,
                    group['lr'], group['weight_decay'], group['eps'])

        return loss
    def step(self, closure=None):
        """Perform one optimization step (Adamax variant with Demon
        momentum decay and decoupled weight decay).

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1_base, beta2 = group['betas']
            eps = group['eps']
            k = group['k']
            lr = group['lr']
            weight_decay = group['weight_decay']

            for param in group['params']:
                grad = param.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adamax does not support sparse gradients')

                state = self.state[param]
                if not state:
                    # First update: create the step counter and both moment
                    # buffers for this parameter.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(
                        param, memory_format=torch.preserve_format)
                    state['exp_k'] = torch.zeros_like(
                        param, memory_format=torch.preserve_format)

                state['step'] += 1
                step_count = state['step']
                exp_avg = state['exp_avg']
                exp_k = state['exp_k']

                # Decoupled weight decay applied directly to the weights.
                param.mul_(1 - lr * weight_decay)

                # Demon: decay the momentum coefficient towards zero over
                # self.T total steps.
                remaining = 1 - (step_count / self.T)
                beta1 = beta1_base * remaining / ((1 - beta1_base) +
                                                  beta1_base * remaining)

                # Biased first-moment estimate (in place).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # Biased k-th-moment estimate (in place).
                exp_k.mul_(beta2).add_(grad**k, alpha=1 - beta2)

                bias_correction_1 = 1 - beta1**step_count
                bias_correction_2 = (1 - beta2**step_count)**(1 / k)
                scaled_lr = group['lr'] * bias_correction_2 / bias_correction_1

                # Parameter update: p -= lr' * m_hat / (k-th root of v + eps).
                param.addcdiv_(exp_avg, exp_k.pow(1 / k) + eps,
                               value=-scaled_lr)

        return loss