Example #1
    def copy(self, tensor, only_mutable=False):
        return torch.clone(tensor)
    def forward(self, inputs, core_state=(), mems=None, mem_padding=None):

        x = inputs["frame"]
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0

        for i, fconv in enumerate(self.feat_convs):
            x = fconv(x)
            res_input = x
            x = self.resnet1[i](x)
            x += res_input
            res_input = x
            x = self.resnet2[i](x)
            x += res_input
        x = F.relu(x)
        x = x.view(T * B, -1)
        x = F.relu(self.fc(x))

        #print('inputs: ', inputs)
        #print('inputs last action', inputs['last_action'])
        one_hot_last_action = F.one_hot(inputs["last_action"].view(T * B),
                                        self.num_actions).float()

        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        core_input = torch.cat([x, clipped_reward, one_hot_last_action],
                               dim=-1)

        core_input = core_input.view(T, B, -1)

        padding_mask = torch.clone(inputs['done'])

        ind_first_done = None
        # padding_mask is only 1-D for the very first state produced by env.initialize().
        if padding_mask.dim() > 1:
            # Push the dones one position down so that the loss calculation still
            # accounts for that step instead of ignoring it as masked.
            ind_first_done = padding_mask.long().argmin(0) + 1  # index of the first 1 in each column
            orig_first_row = torch.clone(padding_mask[0, :])
            # If there aren't any 0's in a column of inputs['done'], set its ind_first_done to 0.
            ind_first_done[padding_mask[0, :] == 1] = 0
            # Out-of-range indices become -1, which helps in the learn function.
            ind_first_done[ind_first_done >= padding_mask.shape[0]] = -1
            padding_mask[ind_first_done, range(B)] = False
            padding_mask[0, :] = orig_first_row

        padding_mask = padding_mask.unsqueeze(0)
        if not padding_mask.any().item():  # In this case there is no need for a padding_mask.
            padding_mask = None

        core_output, mems = self.core(
            core_input,
            mems,
            padding_mask=padding_mask,
            mem_padding=mem_padding)  # core_input is of shape (T, B, ...)
        # core_output is (B, ...)

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        policy_logits = policy_logits.reshape(T * B, self.num_actions)
        # # if policy_logits.shape[0] == 32 and policy_logits.shape[1] == 6:
        # if not torch.all(policy_logits == policy_logits).item():
        #     # nans only come when the learner_model calls this forward
        #     print('from monobeast 921\n', policy_logits)
        #     print('core output : ',core_output.shape, '\n', core_output)
        #     print('core input : \n', core_input)
        #     print('mask : \n', padding_mask)
        #     print('mems : \n', mems)
        #     torch.save(core_input, './core_input.pt')
        #     torch.save(padding_mask, './padding_mask.pt')
        #     torch.save(mems, './mems.pt')

        if self.training:
            # Sample from multinomial distribution for exploration
            # if not (padding_mask is None) and padding_mask.shape[1] > 1:
            #     print('Padding shape: {}, logits shape: {}'.format(padding_mask.shape, policy_logits.shape))
            #     print('PADDING: ', padding_mask)
            #     print("LOGITS: ", policy_logits)
            action = torch.multinomial(F.softmax(policy_logits, dim=1),
                                       num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)

        action = action.view(T, B)

        return (dict(policy_logits=policy_logits,
                     baseline=baseline,
                     action=action), core_state, mems, padding_mask,
                ind_first_done)
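To make the done-shifting logic above easier to follow, here is a minimal, self-contained sketch of the same index manipulation on a toy `done` tensor; the function name and the toy shapes are illustrative, not part of the original model.

import torch

def shift_dones(done):
    # done: (T, B) boolean tensor of episode terminations, as in inputs['done'] above.
    padding_mask = torch.clone(done)
    ind_first_done = padding_mask.long().argmin(0) + 1
    orig_first_row = torch.clone(padding_mask[0, :])
    ind_first_done[padding_mask[0, :] == 1] = 0
    ind_first_done[ind_first_done >= padding_mask.shape[0]] = -1
    padding_mask[ind_first_done, range(done.shape[1])] = False
    padding_mask[0, :] = orig_first_row
    return padding_mask, ind_first_done

done = torch.tensor([[1, 0], [1, 0], [0, 0], [0, 1]], dtype=torch.bool)
mask, idx = shift_dones(done)
print(mask, idx)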
Example #3
def train(data_dir, model_dir, args):
    seed_everything(args.seed)
    # args.__dict__ == vars(args)
    wandb.init(project="train_01", config=vars(args))

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # -- dataset
    dataset_module = getattr(import_module("dataset"),
                             args.dataset)  # MaskBaseDataset
    dataset = dataset_module(data_dir=data_dir, val_ratio=args.val_ratio)
    num_classes = dataset.num_classes  # 18

    # -- augmentation
    transform_module = getattr(import_module("dataset"),
                               args.augmentation)  # default: BaseAugmentation
    transform = transform_module(
        resize=args.resize,
        mean=dataset.mean,
        std=dataset.std,
    )
    dataset.set_transform(transform)

    # -- data_loader
    train_set, val_set = dataset.split_dataset()

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=8,
                              shuffle=True,
                              pin_memory=use_cuda,
                              drop_last=True)

    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            num_workers=8,
                            shuffle=False,
                            pin_memory=use_cuda,
                            drop_last=True)

    # -- model
    model_module = getattr(import_module("model"),
                           args.model)  # default: BaseModel
    model = model_module(num_classes=num_classes,
                         grad_point=args.grad_point).to(device)
    model = torch.nn.DataParallel(model)
    # Optionally resume training from the args.continue_epoch checkpoint.
    if args.continue_train:
        try_dir = find_dir_try(args.continue_try_num, model_dir,
                               args.continue_name)
        epoch_dir = find_dir_epoch(args.continue_epoch, try_dir)
        model.load_state_dict(torch.load(epoch_dir))

    # -- loss & metric
    if args.criterion == "cross_entropy":
        criterion = create_criterion(args.criterion)  # default: cross_entropy
    else:
        criterion = create_criterion(
            args.criterion, classes=num_classes)  # default: cross_entropy
    if args.optimizer == "AdamP":
        optimizer = AdamP(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          weight_decay=5e-4)
    else:
        opt_module = getattr(import_module('torch.optim'),
                             args.optimizer)  # default: Adam
        optimizer = opt_module(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr,
                               weight_decay=5e-4)
    scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

    # -- logging
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    with open(Path(save_dir) / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc = 0
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        # train loop
        model.train()
        loss_value = 0
        matches = 0
        for idx, train_batch in enumerate(train_loader):
            inputs, labels = train_batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            loss = criterion(outs, labels)

            loss.backward()
            optimizer.step()

            loss_value += loss.item()
            matches += (preds == labels).sum().item()
            if (idx + 1) % args.log_interval == 0:
                train_loss = loss_value / args.log_interval
                train_acc = matches / args.batch_size / args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
                )
                wandb.log({
                    "Train/loss": train_loss,
                    "Train/accuracy": train_acc
                })

                loss_value = 0
                matches = 0

        scheduler.step()

        #val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_loss_items = []
            val_acc_items = []
            figure = None
            for val_batch in val_loader:
                inputs, labels = val_batch
                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)

                loss_item = criterion(outs, labels).item()
                acc_item = (labels == preds).sum().item()
                val_loss_items.append(loss_item)
                val_acc_items.append(acc_item)

                if figure is None:
                    # inputs_np = torch.clone(inputs).detach().cpu().permute(0, 2, 3, 1).numpy()
                    inputs_np = torch.clone(inputs).detach().cpu()
                    inputs_np = inputs_np.permute(0, 2, 3, 1).numpy()
                    inputs_np = dataset_module.denormalize_image(
                        inputs_np, dataset.mean, dataset.std)
                    figure = grid_image(
                        inputs_np, labels, preds,
                        args.dataset != "MaskSplitByProfileDataset")
                    plt.show()

            val_loss = np.sum(val_loss_items) / len(val_loader)
            val_acc = np.sum(val_acc_items) / len(val_set)
            if val_loss < best_val_loss or val_acc > best_val_acc:
                save_model(model, epoch, val_loss, val_acc, save_dir,
                           args.model)
                if val_loss < best_val_loss and val_acc > best_val_acc:
                    print(
                        f"New best model for val acc and val loss : {val_acc:4.2%} {val_loss:4.2}! saving the best model.."
                    )
                    best_val_loss = val_loss
                    best_val_acc = val_acc
                elif val_loss < best_val_loss:
                    print(
                        f"New best model for val loss : {val_loss:4.2}! saving the best model.."
                    )
                    best_val_loss = val_loss
                elif val_acc > best_val_acc:
                    print(
                        f"New best model for val accuracy : {val_acc:4.2%}! saving the best model.."
                    )
                    best_val_acc = val_acc

            print(
                f"[Val] acc: {val_acc:4.2%}, loss: {val_loss:4.2} || "
                f"best acc: {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
            )
            wandb.log({"Val/loss": val_loss, "Val/accuracy": val_acc})
            print()
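The getattr(import_module(...)) calls in train() above resolve dataset, augmentation, model, and optimizer classes from their string names at runtime. A minimal sketch of that pattern, using a class that is known to exist (the looked-up name here is just an example):

from importlib import import_module

def load_class(module_name, class_name):
    # Resolve a class from its string name, e.g. load_class("dataset", "MaskBaseDataset").
    module = import_module(module_name)
    return getattr(module, class_name)

# Example with a module that is guaranteed to be importable:
StepLR = load_class("torch.optim.lr_scheduler", "StepLR")
print(StepLR)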
def lamw(x):
    I = x > 1e-10
    y = torch.clone(x)
    y[I] = lambertw(x[I])
    return y
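lamw above shows a common pattern: clone the input and overwrite only the entries selected by a mask, leaving the rest untouched. A minimal sketch of the same pattern with an ordinary function in place of lambertw (which is assumed to be a torch-compatible Lambert-W implementation):

import torch

def safe_log1p_over_x(x):
    # Apply log1p(x)/x only where x is clearly positive; keep other entries as-is.
    mask = x > 1e-10
    y = torch.clone(x)
    y[mask] = torch.log1p(x[mask]) / x[mask]
    return y

print(safe_log1p_over_x(torch.tensor([0.0, 1e-12, 0.5, 2.0])))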
Example #5
def sgd_step(optimizer, detach_dp=True):
    """Performs a single optimization step using the SGD optimizer. The code
    has been copied from:

        https://git.io/fjYit

    Note, this function does not change the inner state of the given
    optimizer object.

    Note, gradients are cloned and detached by default.

    Args:
        optimizer: An instance of class :class:`torch.optim.SGD`.
        detach_dp: Whether gradients are detached from the computational
            graph. Note, :code:`False` only makes sense if
            :func:`torch.autograd.backward` was called with the argument
            `create_graph` set to :code:`True`.

    Returns:
        A list of gradient changes `d_p` that would be applied by this
        optimizer to all parameters when calling :meth:`torch.optim.SGD.step`.
    """
    assert isinstance(optimizer, optim.SGD)

    d_ps = []

    for group in optimizer.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        for p in group['params']:
            if p.grad is None:
                continue

            if detach_dp:
                d_p = p.grad.detach().clone()
            else:
                d_p = p.grad.clone()

            if weight_decay != 0:
                d_p.add_(p.data, alpha=weight_decay)
            if momentum != 0:
                orig_state = dict(optimizer.state[p])
                param_state = dict()

                if 'momentum_buffer' in orig_state:
                    param_state['momentum_buffer'] = \
                        orig_state['momentum_buffer'].clone()

                if 'momentum_buffer' not in param_state:
                    buf = torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf

            d_ps.append(-group['lr'] * d_p)

    return d_ps
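A minimal usage sketch for sgd_step; the model, the dummy loss, and the manual application of the returned deltas are illustrative assumptions, not part of the original module:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()

# Compute the changes SGD would apply, without touching the optimizer state.
d_ps = sgd_step(optimizer)

# Apply them manually; the order matches the parameters that have gradients.
with torch.no_grad():
    params = [p for g in optimizer.param_groups for p in g['params'] if p.grad is not None]
    for p, d_p in zip(params, d_ps):
        p.add_(d_p)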
Example #6
    def step(self,
             closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None and callable(closure):
            loss = closure()

        # step counter must be stored in state to ensure correct behavior under
        # optimizer sharding
        if "k" not in self.state:
            self.state["k"] = torch.tensor([0], dtype=torch.long)
        k = self.state["k"].item()

        for group in self.param_groups:
            eps = group["eps"]
            lr = group["lr"] + eps
            decay = group["weight_decay"]
            momentum = group["momentum"]

            ck = 1 - momentum
            lamb = lr * math.pow(k + 1, 0.5)

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                if "grad_sum_sq" not in state:
                    state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
                    state["s"] = torch.zeros_like(p.data).detach()
                    if momentum != 0:
                        state["x0"] = torch.clone(p.data).detach()

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError(
                        "momentum != 0 is not compatible with sparse gradients"
                    )

                grad_sum_sq = state["grad_sum_sq"]
                s = state["s"]

                # Apply weight decay - L2 / AdamW style
                if decay:
                    p.data.mul_(1 - lr * decay)
                """ original impl:
                if decay != 0:
                    if grad.is_sparse:
                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")

                    grad.add_(p.data, alpha=decay)
                """

                if grad.is_sparse:
                    grad = grad.coalesce()
                    grad_val = grad._values()

                    p_masked = p.sparse_mask(grad)
                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
                    s_masked = s.sparse_mask(grad)

                    # Compute x_0 from other known quantities
                    rms_masked_vals = grad_sum_sq_masked._values().pow(
                        1 / 3).add_(eps)
                    x0_masked_vals = p_masked._values().addcdiv(
                        s_masked._values(), rms_masked_vals, value=1)

                    # Dense + sparse op
                    grad_sq = grad * grad
                    grad_sum_sq.add_(grad_sq, alpha=lamb)
                    grad_sum_sq_masked.add_(grad_sq, alpha=lamb)

                    rms_masked_vals = grad_sum_sq_masked._values().pow_(
                        1 / 3).add_(eps)

                    s.add_(grad, alpha=lamb)
                    s_masked._values().add_(grad_val, alpha=lamb)

                    # update masked copy of p
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(
                        s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
                    p.data.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

                    # Accumulate second moments
                    grad_sum_sq.addcmul_(grad, grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.data.add_(grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.data.mul_(1 - ck).add_(z, alpha=ck)

        self.state["k"] += 1
        return loss
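The dense branch above resembles the MADGRAD-style dual-averaging update. A stripped-down restatement for a single parameter tensor, with hand-picked hyperparameters and no sparse handling, may make the algebra easier to check; everything here is a simplified assumption, not the optimizer's actual API:

import math
import torch

p = torch.zeros(3)                      # parameter
grad = torch.tensor([0.1, -0.2, 0.3])   # pretend gradient
lr, eps, momentum, k = 1e-2, 1e-6, 0.9, 0
ck = 1 - momentum
lamb = (lr + eps) * math.pow(k + 1, 0.5)

grad_sum_sq = torch.zeros_like(p)
s = torch.zeros_like(p)
x0 = p.clone()                          # kept fixed when momentum != 0

grad_sum_sq.addcmul_(grad, grad, value=lamb)   # accumulate second moments
rms = grad_sum_sq.pow(1 / 3).add_(eps)
s.add_(grad, alpha=lamb)                       # accumulate first moments

z = x0.addcdiv(s, rms, value=-1)               # dual-averaging iterate
p.mul_(1 - ck).add_(z, alpha=ck)               # p is a moving average of z
print(p)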
Example #7
def batched_powerSGD_hook(state: PowerSGDState,
                          bucket) -> torch.futures.Future:
    """
    This DDP communication hook implements a simplified PowerSGD gradient compression
    algorithm described in https://arxiv.org/abs/1905.13727.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression to the flattened input tensor that batches per-parameter tensors as follows:
    1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
    2) Creates two low-rank tensors P and Q for decomposing M,
    such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
    3) Computes P, which is equal to MQ;
    4) Allreduces P;
    5) Orthogonalizes P;
    6) Computes Q, which is approximately equal to M^TP;
    7) Allreduces Q;
    8) Computes M, which is approximately equal to PQ^T;
    9) Truncates the input tensor to the original length.

    This variant is faster than `powerSGD_hook` that runs layer-wise gradient compression,
    but it usually results in a much lower accuracy, unless `matrix_approximation_rank` in the state is 1.
    Increasing `matrix_approximation_rank` may not necessarily increase the accuracy,
    because batching per-parameter tensors without column/row alignment can destroy low-rank structure.
    Therefore, the user should always consider `powerSGD_hook` first,
    and only consider this variant when a satisfactory accuracy can be achieved with `matrix_approximation_rank` set to 1.

    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
    This can not only allow the user to have a finer tuning over the tradeoff between speedup and accuracy,
    but also help abstract away some complexity of the internal optimization of DDP for future communication hook developers.

    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
    one left multiplication and one right multiplication.
    For warm-start, can take one such step at a time, and alternate between them.

    Args:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode at this time,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
        >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
    """
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]

    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(group_to_use, input_tensor)

    # Apply PowerSGD after `start_powerSGD_iter` iterations.
    device = input_tensor.device
    total_length = input_tensor.shape[0]

    # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary.
    square_side_length = math.ceil(math.sqrt(total_length))
    padded_total_length = square_side_length**2
    input_tensor.resize_(padded_total_length)
    input_tensor[total_length:padded_total_length].fill_(0)

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.get_index()
    input_tensor_cp = None
    if state.use_error_feedback:
        if bucket_index in state.error_dict:
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logging.info(
                "A zero tensor of length {} that represents local error is created."
                .format(padded_total_length))
            state.error_dict[bucket_index] = torch.zeros(
                padded_total_length, device=device, dtype=input_tensor.dtype)

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()
    matrix = input_tensor.view(square_side_length, square_side_length)

    # Reuse P and Q from the previous iteration if possible.
    # The memory spaces of P and Q need to be allocated in the first iteration when PowerSGD is applied.
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        # If warm-start is disabled, low-rank tensors will be initialized at every step.
        # Only log this if warm-start to avoid spamming.
        if state.warm_start:
            logging.info(
                "Initializing low-rank tensors P and Q, each of which has a shape of {} x {}."
                .format(square_side_length, state.matrix_approximation_rank))

        def create_low_rank_tensor(fill_random_values, rng):
            "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
            if fill_random_values:
                with torch.random.fork_rng(devices=[]):
                    # Fork this RNG to avoid changing the seed globally and affecting the random sampling
                    # anywhere else in the training.
                    # The seed makes sure that the initial random values are the same across all the DDP replicas.
                    # This seed should differ at every step.
                    # Since it is very slow to fork RNG state across all the CUDA devices,
                    # only fork on CPU and then move the generated tensor to the CUDA device.
                    torch.manual_seed(rng.randint(1_000_000_000))
                    return torch.randn(
                        square_side_length,
                        state.matrix_approximation_rank,
                        device="cpu",
                        dtype=input_tensor.dtype,
                    ).to(device)
            else:
                return torch.empty(
                    square_side_length,
                    state.matrix_approximation_rank,
                    device=device,
                    dtype=input_tensor.dtype,
                )

        state.p_memory_dict[bucket_index] = create_low_rank_tensor(
            fill_random_values=False, rng=state.rng)
        state.q_memory_dict[bucket_index] = create_low_rank_tensor(
            fill_random_values=True, rng=state.rng)
    _orthogonalize(state.q_memory_dict[bucket_index], 0)

    torch.matmul(matrix,
                 state.q_memory_dict[bucket_index],
                 out=state.p_memory_dict[bucket_index])
    allreduce_p_fut = dist.all_reduce(state.p_memory_dict[bucket_index],
                                      group=group_to_use,
                                      async_op=True).get_future()

    def compute_q(fut):
        state.p_memory_dict[bucket_index] = fut.value()[0]
        _orthogonalize(state.p_memory_dict[bucket_index], 0)

        torch.matmul(
            matrix.t(),
            state.p_memory_dict[bucket_index],
            out=state.q_memory_dict[bucket_index],
        )

        return [
            dist.all_reduce(state.q_memory_dict[bucket_index],
                            group=group_to_use,
                            async_op=True).get_future().wait()[0]
        ]

    def decompress(fut):
        state.q_memory_dict[bucket_index] = fut.value()[0].div_(world_size)
        torch.matmul(
            state.p_memory_dict[bucket_index],
            state.q_memory_dict[bucket_index].t(),
            out=matrix,
        )

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)
        if not state.warm_start:
            state.p_memory_dict.clear()
            state.q_memory_dict.clear()
        ret = input_tensor.resize_(total_length)

        state.maybe_increase_iter(bucket)

        return [ret]

    return allreduce_p_fut.then(compute_q).then(decompress)
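Stripped of the distributed plumbing, the compression in batched_powerSGD_hook reduces to a few tensor operations. A minimal single-process sketch; the Gram-Schmidt helper stands in for the private _orthogonalize, and the shapes are arbitrary:

import math
import torch

def orthogonalize(m):
    # Simple Gram-Schmidt over the columns of m, in place.
    for i in range(m.shape[1]):
        col = m[:, i:i + 1]
        col /= col.norm()
        if i + 1 < m.shape[1]:
            rest = m[:, i + 1:]
            rest -= (col.t() @ rest) * col

grad = torch.randn(1000)                       # flattened gradient bucket
rank = 1
side = math.ceil(math.sqrt(grad.numel()))
padded = torch.zeros(side * side)
padded[:grad.numel()] = grad
M = padded.view(side, side)

Q = torch.randn(side, rank)
orthogonalize(Q)
P = M @ Q                                      # would be allreduced in DDP
orthogonalize(P)
Q = M.t() @ P                                  # would be allreduced in DDP
M_approx = P @ Q.t()

error = (M - M_approx).norm() / M.norm()
print(f"relative compression error: {error:.3f}")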
Example #8
    def step(self, reg_params, closure=None):

        loss = None

        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:

                if p.grad is None:
                    continue

                d_p = p.grad.data

                if p in reg_params:

                    param_dict = reg_params[p]

                    omega = param_dict['omega']
                    init_val = param_dict['init_val']

                    curr_param_value = p.data
                    curr_param_value = curr_param_value.cuda()

                    init_val = init_val.cuda()
                    omega = omega.cuda()

                    # Difference between the current and initial parameter values.
                    param_diff = curr_param_value - init_val

                    # Gradient of the penalty term that discourages changes to important weights.
                    local_grad = torch.mul(param_diff, 2 * self.reg_lambda * omega)

                    del param_diff
                    del omega
                    del init_val
                    del curr_param_value

                    d_p = d_p + local_grad

                    del local_grad

                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss
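A small sketch of the reg_params structure this step expects and of the penalty gradient 2 * reg_lambda * omega * (theta - theta_init) it adds; the tensors, the lambda value, and the CPU-only setting are illustrative assumptions (the original moves everything to CUDA):

import torch

reg_lambda = 0.01
p = torch.tensor([1.0, 2.0, 3.0])                 # current parameter values
reg_params = {
    'omega': torch.tensor([0.5, 0.0, 2.0]),       # per-weight importance
    'init_val': torch.tensor([1.0, 1.0, 1.0]),    # values at the end of the previous task
}

param_diff = p - reg_params['init_val']
local_grad = torch.mul(param_diff, 2 * reg_lambda * reg_params['omega'])
print(local_grad)   # penalty gradient added to d_p in the step above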
Example #9
def train_cnn_pytorch():
    image_dim = 32
    hidden_dim = 200
    output_dim = 10
    kernel_dim = 3
    kernel_num = 64
    batch_size = 8
    lr = 0.01
    dp_rate = 0.3
    epochs = 1000

    best_result = [0, 0]
    no_update = 0

    # os.environ["CUDA_VISIBLE_DEVICES"] = 0
    print("Start training")
    model = CNN(batch_size=batch_size,
                input_dim=image_dim,
                hidden_dim=hidden_dim,
                output_dim=output_dim,
                kernel_num=kernel_num,
                kernel_dim=kernel_dim,
                dp_rate=dp_rate)
    if torch.cuda.is_available():
        model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        train_data = MyDataset("digits/trainingDigits")
        train_loader = data.DataLoader(train_data,
                                       batch_size=batch_size,
                                       num_workers=0,
                                       shuffle=True)
        model.train()
        start = time.time()
        print(f"Epoch {epoch} start ")
        avg_loss = 0
        count = 0
        for step, input_data in enumerate(train_loader):
            x = torch.clone(input_data[0]).float()
            target = torch.clone(input_data[1]).long()
            if torch.cuda.is_available():
                x = x.cuda()
                target = target.cuda()
            prediction = model(x)
            loss = F.cross_entropy(prediction, target.argmax(dim=1))
            avg_loss += loss.item()
            count += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_loss /= len(train_data)
        end = time.time()
        print(
            f"Epoch {epoch} done, Train average loss: {avg_loss}, costing time: {end - start}"
        )

        if epoch % 50 == 0:
            accuracy, wrong_numbers = evaluate_cnn_pytorch(model, batch_size)
            if accuracy > best_result[0]:
                best_result[0] = accuracy
                best_result[1] = wrong_numbers
                no_update = 0
            else:
                no_update += 1
        if no_update >= 5:
            print("Best Accuracy on test data: " + str(best_result[0]) + "%")
            print(f"Best wrong_numbers: {best_result[1]}")
            return
    print("Best Accuracy on test data: " + str(best_result[0]) + "%")
    print(f"Best wrong_numbers: {best_result[1]}")
Example #10
def rotate(img):
    imgs = []
    imgs.append(torch.rot90(torch.clone(img), k=1, dims=[1, 2]))
    imgs.append(torch.rot90(torch.clone(img), k=2, dims=[1, 2]))
    imgs.append(torch.rot90(torch.clone(img), k=3, dims=[1, 2]))
    return imgs
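A short usage sketch for rotate, e.g. for simple rotation-based test-time augmentation; the classifier and the averaging step are assumptions:

import torch

img = torch.randn(3, 32, 32)                 # C, H, W
views = [img] + rotate(img)                  # original plus 90/180/270-degree rotations
batch = torch.stack(views)                   # (4, 3, 32, 32)

# With some classifier `model`, predictions could be averaged over the views:
# probs = torch.softmax(model(batch), dim=1).mean(dim=0)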
Example #11
def powerSGD_hook(
    state: PowerSGDState,
    bucket,
) -> torch.futures.Future:
    """
    This DDP communication hook implements a simplified PowerSGD gradient compression
    algorithm described in https://arxiv.org/abs/1905.13727.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:
    1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
    2) Creates two low-rank tensors P and Q for decomposing M,
    such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
    3) Computes P, which is equal to MQ;
    4) Allreduces P;
    5) Orthogonalizes P;
    6) Computes Q, which is approximately equal to M^TP;
    7) Allreduces Q;
    8) Computes M, which is approximately equal to PQ^T;
    9) Truncates the input tensor to the original length.

    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
    one left multiplication and one right multiplication.
    For warm start, can take one such step at a time, and alternate between them.

    Arguments:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression rate, adjust ``matrix_approximation_rank`` in the state;
            typically only 1 or 2 is used (see https://arxiv.org/pdf/1905.13727.pdf).
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode at this time,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
    device = input_tensor.device
    total_length = input_tensor.shape[0]

    # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary.
    square_side_length = math.ceil(math.sqrt(total_length))
    padded_total_length = square_side_length ** 2
    input_tensor.resize_(padded_total_length)
    input_tensor[total_length:padded_total_length].fill_(0)

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.get_index()
    if state.use_error_feedback:
        # The buckets can be rebuilt during training.
        # In this case, the error tensor shape will not be aligned with the input tensor,
        # and the error will be re-initialized as zeros.
        if (
            bucket_index in state.error_dict
            and state.error_dict[bucket_index].shape[0] == padded_total_length
        ):
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logging.info(
                "A zero tensor of length {} that represents local error is created.".format(
                    padded_total_length
                )
            )
            state.error_dict[bucket_index] = torch.zeros(
                padded_total_length, device=device
            )

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()
    matrix = input_tensor.view(square_side_length, square_side_length)

    def create_low_rank_tensor(fill_random_values, rng):
        "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
        if fill_random_values:
            with torch.random.fork_rng(devices=[]):
                # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
                # The seed makes sure that the initial random values are the same across all the DDP replicas.
                # Such seed should differ at every step.
                # Since it is very slow to fork RNG state across all the CUDA devices,
                # only fork on CPU and then move the generated tensor to the CUDA device.
                torch.manual_seed(rng.randint(1_000_000_000))
                return torch.randn(
                    square_side_length, state.matrix_approximation_rank, device="cpu"
                ).to(device)
        else:
            return torch.empty(
                square_side_length, state.matrix_approximation_rank, device=device
            )

    p = create_low_rank_tensor(fill_random_values=False, rng=state.rng)
    q = create_low_rank_tensor(fill_random_values=True, rng=state.rng)
    _orthogonalize(q, 0)

    torch.matmul(matrix, q, out=p)
    allreduce_p_fut = dist.all_reduce(p, group=group_to_use, async_op=True).get_future()

    def compute_q(fut):
        p = fut.value()[0]
        _orthogonalize(p, 0)

        torch.matmul(matrix.t(), p, out=q)

        return [
            dist.all_reduce(q, group=group_to_use, async_op=True)
            .get_future()
            .value()[0]
        ]

    def decompress(fut):
        q = fut.value()[0].div_(world_size)
        torch.matmul(p, q.t(), out=matrix)

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        ret = input_tensor.resize_(total_length)
        return [ret]

    return allreduce_p_fut.then(compute_q).then(decompress)
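The error-feedback bookkeeping in both PowerSGD hooks follows the same two-step pattern: add last step's residual to the gradient before compressing, then store the new residual afterwards. A minimal sketch with a generic lossy compress_decompress placeholder (the top-k compressor is only an illustration):

import torch

def compress_decompress(t, keep=0.1):
    # Placeholder lossy compressor: keep only the largest-magnitude entries.
    k = max(1, int(t.numel() * keep))
    out = torch.zeros_like(t)
    idx = t.abs().topk(k).indices
    out[idx] = t[idx]
    return out

error = torch.zeros(100)                      # plays the role of state.error_dict[bucket_index]
for step in range(3):
    grad = torch.randn(100)
    grad.add_(error)                          # incorporate the previous residual
    grad_cp = torch.clone(grad).detach()      # keep a copy, as in the hooks
    grad = compress_decompress(grad)          # lossy round trip
    error = grad_cp - grad                    # memorize the new local error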
Example #12
    def step(self, closure=None):
        assert self.is_ml
        # NCE is only defined for AGD
        if self.NCE:
            assert self.momentum > 0
        # need to define closure to use noise, AGD, or NCE
        if self.noise_r > 0 or self.momentum > 0 or self.NCE:
            assert closure is not None
        loss = None
        if closure is not None:
            loss = closure()

        params = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                params.append(p)

        # noise
        if self.noise_r > 0:
            param_state = self.state[params[0]]
            if 'noise_count' not in param_state:
                param_state['noise_count'] = 0
            all_grad = torch.cat([p.grad.reshape(-1) for p in params], dim=0)
            all_grad_norm = torch.norm(all_grad, p=2).item()
            if self.is_verbose:
                print("*Noise part* grad l2 norm: %.3f, count: %d" %
                      (all_grad_norm, param_state['noise_count']))
            if all_grad_norm <= self.noise_eps and param_state[
                    'noise_count'] >= self.noise_T:
                param_state['noise_count'] = 0
                radius = torch.pow(torch.rand(1),
                                   1.0 / all_grad.numel()).mul(self.noise_r)
                gauss = torch.randn(all_grad.size())
                normed_gauss = gauss.div(torch.norm(gauss, p=2))
                noise = normed_gauss.mul(radius)
                if self.is_verbose:
                    print("add noise with l2 norm:", radius)
                i = 0
                for p in params:
                    p.data.add_(noise[i:i + p.numel()].reshape(p.size()))
                    i = i + p.numel()
            else:
                param_state['noise_count'] += 1
                if self.is_verbose:
                    print("no noise added")

        # AGD
        if self.momentum > 0:
            if self.NCE:
                xt = torch.tensor([])
                yt = torch.tensor([])
                g_yt = torch.tensor([])
                vt = torch.tensor([])
            for p in params:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = torch.zeros(
                        p.size())
                else:
                    buf = param_state['momentum_buffer']
                if self.NCE:
                    xt = torch.cat(
                        [xt, torch.clone(p.data).detach().reshape(-1)], dim=0)
                    vt = torch.cat(
                        [vt, torch.clone(buf.data).detach().reshape(-1)],
                        dim=0)
                    f_xt = loss.item()
                p.data.add_(buf, alpha=self.momentum)
            loss = closure()
            for p in params:
                buf = self.state[p]['momentum_buffer']
                if self.NCE:
                    yt = torch.cat(
                        [yt, torch.clone(p.data).detach().reshape(-1)], dim=0)
                    g_yt = torch.cat(
                        [g_yt,
                         torch.clone(p.grad.data).detach().reshape(-1)],
                        dim=0)
                    f_yt = loss.item()
                p.data.add_(p.grad.data, alpha=-self.lr)
                buf.mul_(self.momentum).add_(p.grad.data, alpha=-self.lr)
        else:
            for p in params:
                p.data.add_(p.grad.data, alpha=-self.lr)

        # NCE
        def copy_to_p(params, update):
            i = 0
            for p in params:
                p.data = update[i:i + p.numel()].reshape(p.size())
                i = i + p.numel()

        if self.NCE:
            norm_vt = torch.norm(vt, p=2)
            if self.is_verbose:
                print(
                    "*NCE part* f(xt): %.3f, f(yt): %.3f, <grad_f(yt),xt-yt>: %.3f, ||xt-yt||^2: %.3f, ||vt||: %.3f"
                    % (f_xt, f_yt, g_yt.dot(
                        (xt - yt)).item(), torch.norm(
                            (xt - yt), p=2).pow(2).item(), norm_vt.item()))
            if norm_vt > 0 and f_xt <= f_yt + g_yt.dot(
                (xt - yt)) - self.NCE_gamma / 2 * (torch.norm(
                    (xt - yt), p=2).pow(2)):
                for p in params:
                    self.state[p]['momentum_buffer'] = torch.zeros(p.size())
                if norm_vt >= self.NCE_s:
                    copy_to_p(params, xt)
                    if self.is_verbose:
                        print("setting x_{t+1} = xt")
                else:
                    delta = vt.mul(self.NCE_s).div(norm_vt)
                    copy_to_p(params, xt + delta)
                    loss_ = closure()
                    copy_to_p(params, xt - delta)
                    loss = closure()
                    if (loss_ < loss):
                        if self.is_verbose:
                            print("setting x_{t+1} = xt + delta")
                        copy_to_p(params, xt + delta)
                        loss = closure()
                    else:
                        if self.is_verbose:
                            print("setting x_{t+1} = xt - delta")
            else:
                if self.is_verbose:
                    print("no change by NCE")
        return loss
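The copy_to_p helper above writes a single flat vector back into a list of parameter tensors. A tiny round-trip sketch of that flatten/unflatten pattern; the parameter shapes are arbitrary:

import torch

params = [torch.zeros(2, 3), torch.zeros(4)]

# Flatten, as done for xt / vt above.
flat = torch.cat([torch.clone(p.data).detach().reshape(-1) for p in params], dim=0)
flat += 1.0  # some update computed in the flat space

# Unflatten, as copy_to_p does.
i = 0
for p in params:
    p.data = flat[i:i + p.numel()].reshape(p.size())
    i += p.numel()
print([p.shape for p in params], params[0])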
Example #13
    def forward(self, x, vars=None, bn_training=True):
        """
        This function can be called by finetunning, however, in finetunning, we dont wish to update
        running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights.
        Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
        but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
        :param x: [b, 1, 28, 28]
        :param vars:
        :param bn_training: set False to not update
        :return: x, loss, likelihood, kld
        """

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        first_upsample = True

        blocks = []
        for name, param in self.config:
            if name == 'conv2d':
                # periodic padding:
                kernel_width = param[2]
                if kernel_width % 2:
                    x_pad = torch.cat([
                        x[:, :, :, x.shape[3] -
                          int((kernel_width - 1) / 2):x.shape[3]], x
                    ],
                                      dim=3)
                    x_pad = torch.cat(
                        [x_pad, x[:, :, :, 0:int((kernel_width - 1) / 2)]],
                        dim=3)
                    x_pad = torch.cat([
                        torch.zeros_like(x_pad[:, :, 0:int((kernel_width - 1) /
                                                           2), :]), x_pad
                    ],
                                      dim=2)
                    x_pad = torch.cat([
                        x_pad,
                        torch.zeros_like(
                            x_pad[:, :, 0:int((kernel_width - 1) / 2), :])
                    ],
                                      dim=2)
                else:
                    x_pad = torch.cat([
                        x[:, :, :,
                          x.shape[3] - int(kernel_width / 2):x.shape[3]], x
                    ],
                                      dim=3)
                    x_pad = torch.cat(
                        [x_pad, x[:, :, :, 0:int((kernel_width / 2) - 1)]],
                        dim=3)
                    x_pad = torch.cat([
                        torch.zeros_like(
                            x_pad[:, :, 0:int(kernel_width / 2), :]), x_pad
                    ],
                                      dim=2)
                    x_pad = torch.cat([
                        x_pad,
                        torch.zeros_like(x_pad[:, :,
                                               0:int(kernel_width / 2) - 1, :])
                    ],
                                      dim=2)

                w = vars[idx]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv2d(x_pad, w, stride=param[4], padding=0)
                idx += 1
                if param[6]:
                    self.residual_terms = torch.clone(x)
                # print(name, param, '\tout:', x.shape)
            elif name == 'conv2d_b':
                kernel_width = param[2]
                if kernel_width % 2:
                    x_pad = torch.cat([
                        x[:, :, :, x.shape[3] -
                          int((kernel_width - 1) / 2):x.shape[3]], x
                    ],
                                      dim=3)
                    x_pad = torch.cat(
                        [x_pad, x[:, :, :, 0:int((kernel_width - 1) / 2)]],
                        dim=3)
                    x_pad = torch.cat([
                        torch.zeros_like(x_pad[:, :, 0:int((kernel_width - 1) /
                                                           2), :]), x_pad
                    ],
                                      dim=2)
                    x_pad = torch.cat([
                        x_pad,
                        torch.zeros_like(
                            x_pad[:, :, 0:int((kernel_width - 1) / 2), :])
                    ],
                                      dim=2)
                else:
                    x_pad = torch.cat([
                        x[:, :, :,
                          x.shape[3] - int(kernel_width / 2):x.shape[3]], x
                    ],
                                      dim=3)
                    x_pad = torch.cat(
                        [x_pad, x[:, :, :, 0:int((kernel_width / 2) - 1)]],
                        dim=3)
                    x_pad = torch.cat([
                        torch.zeros_like(
                            x_pad[:, :, 0:int(kernel_width / 2), :]), x_pad
                    ],
                                      dim=2)
                    x_pad = torch.cat([
                        x_pad,
                        torch.zeros_like(x_pad[:, :,
                                               0:int(kernel_width / 2) - 1, :])
                    ],
                                      dim=2)

                w, b = vars[idx], vars[idx + 1]
                # remember to keep forward_encoder and forward_decoder synchronized!
                x = F.conv2d(x_pad, w, b, stride=param[4], padding=0)
                idx += 2
                # print(name, param, '\tout:', x.shape)

            # elif name == 'convt2d':
            #     w, b = vars[idx], vars[idx + 1]
            #     # remember to keep synchrozied of forward_encoder and forward_decoder!
            #     x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
            #     idx += 2
            #     # print(name, param, '\tout:', x.shape)
            # elif name == 'linear':
            #     w, b = vars[idx], vars[idx + 1]
            #     x = F.linear(x, w, b)
            #     idx += 2
            #     # print('forward:', idx, x.norm().item())
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[
                    bn_idx + 1]
                x = F.batch_norm(x,
                                 running_mean,
                                 running_var,
                                 weight=w,
                                 bias=b,
                                 training=bn_training)
                idx += 2
                bn_idx += 2

            # elif name == 'flatten':
            #     # print(x.shape)
            #     x = x.view(x.size(0), -1)
            # elif name == 'reshape':
            #     # [b, 8] => [b, 2, 2, 2]
            #     x = x.view(x.size(0), *param)
            # elif name == 'relu':
            #     x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            # elif name == 'tanh':
            #     x = F.tanh(x)
            # elif name == 'sigmoid':
            #     x = torch.sigmoid(x)
            elif name == 'upsample':
                if first_upsample:
                    first_upsample = False
                    x = blocks.pop()
                shortcut = blocks.pop()
                x = F.interpolate(x,
                                  size=(shortcut.shape[2], shortcut.shape[3]),
                                  mode='nearest')
                x = torch.cat([shortcut, x], dim=1)  # batch, channels, h, w

            elif name == 'residual':
                x = x + self.residual_terms

            elif name == 'max_pool2d':
                blocks.append(x)
                x = F.max_pool2d(x,
                                 param[0],
                                 stride=param[1],
                                 padding=param[2])
            # elif name == 'avg_pool2d':
            #     x = F.avg_pool2d(x, param[0], param[1], param[2])

            else:
                print(name)
                raise NotImplementedError

        # make sure variable is used properly
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x
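The long torch.cat chains above implement periodic padding along the width plus zero padding along the height before each convolution. For odd kernel widths this appears equivalent to the F.pad sketch below; the even-kernel case in the original is asymmetric and would need slightly different pad amounts:

import torch
import torch.nn.functional as F

def periodic_pad(x, kernel_width):
    # x: (N, C, H, W); assumes an odd kernel_width.
    k = (kernel_width - 1) // 2
    x = F.pad(x, (k, k, 0, 0), mode='circular')              # wrap around along width
    x = F.pad(x, (0, 0, k, k), mode='constant', value=0.0)   # zeros along height
    return x

x = torch.randn(1, 2, 8, 8)
print(periodic_pad(x, 5).shape)   # torch.Size([1, 2, 12, 12])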
Example #14
def train(data_dir, model_dir, args):
    seed_everything(args.seed)

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    info = pd.read_csv('/opt/ml/input/data/train/train.csv')

    info['gender_age'] = info.apply(
        lambda x: convert_gender_age(x.gender, x.age), axis=1)
    n_fold = int(1 / args.val_ratio)

    skf = StratifiedKFold(n_splits=n_fold, shuffle=True)
    info.loc[:, 'fold'] = 0
    for fold_num, (train_index, val_index) in enumerate(
            skf.split(X=info.index, y=info.gender_age.values)):
        info.loc[info.iloc[val_index].index, 'fold'] = fold_num

    fold_idx = 0
    train = info[info.fold != fold_idx].reset_index(drop=True)
    val = info[info.fold == fold_idx].reset_index(drop=True)

    # -- dataset
    dataset_module = getattr(import_module("dataset"),
                             args.dataset)  # default: MaskDataset

    # -- augmentation
    train_transform_module = getattr(
        import_module("dataset"),
        args.train_augmentation)  # default: BaseAugmentation
    val_transform_module = getattr(
        import_module("dataset"),
        args.val_augmentation)  # default: BaseAugmentation

    train_transform = train_transform_module(
        resize=args.resize,
        mean=MEAN,
        std=STD,
    )
    val_transform = val_transform_module(
        resize=args.resize,
        mean=MEAN,
        std=STD,
    )

    print(train_transform.transform, val_transform.transform)

    if args.dataset == 'MaskDataset' or args.dataset == 'MaskOldDataset':
        if args.dataset == 'MaskOldDataset':
            old_transform_module = getattr(import_module('dataset'),
                                           args.old_augmentation)

            old_transform = old_transform_module(
                resize=args.resize,
                mean=MEAN,
                std=STD,
            )
            train_dataset = dataset_module(data_dir, train, train_transform,
                                           old_transform)
            if args.val_old:
                val_dataset = dataset_module(data_dir, val, val_transform,
                                             old_transform)
            else:
                val_dataset = dataset_module(data_dir, val, val_transform)
        else:
            train_dataset = dataset_module(data_dir, train, train_transform)
            val_dataset = dataset_module(data_dir, val, val_transform)
    else:
        dataset = dataset_module(data_dir=data_dir, )

        # dataset.set_transform(transform)
        # -- data_loader
        train_set, val_set = dataset.split_dataset()

        train_dataset = DatasetFromSubset(train_set, transform=train_transform)
        val_dataset = DatasetFromSubset(val_set, transform=val_transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=4,
        shuffle=True,
        pin_memory=use_cuda,
        #drop_last=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.valid_batch_size,
        num_workers=1,
        shuffle=False,
        pin_memory=use_cuda,
        #drop_last=True,
    )

    # -- model
    model_module = getattr(import_module("model"),
                           args.model)  # default: BaseModel
    model = model_module(num_classes=args.num_classes).to(device)
    model = torch.nn.DataParallel(model)

    # -- loss & metric
    if args.criterion == 'f1' or args.criterion == 'label_smoothing':
        criterion = create_criterion(args.criterion, classes=args.num_classes)
    else:
        criterion = create_criterion(args.criterion)

    opt_module = getattr(import_module("torch.optim"),
                         args.optimizer)  # default: SGD
    optimizer = opt_module(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           weight_decay=5e-4)
    if args.scheduler == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=1e-6)
    elif args.scheduler == 'reduce':
        scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
    elif args.scheduler == 'step':
        scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)
    else:
        scheduler = None

    # -- logging
    logger = SummaryWriter(log_dir=save_dir)
    with open(os.path.join(save_dir, 'config.json'), 'w',
              encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc = 0
    best_val_loss = np.inf
    print("This notebook use [%s]." % (device))

    early_stopping = EarlyStopping(patience=args.patience, verbose=True)

    for epoch in range(args.epochs):
        # train loop
        model.train()
        loss_value = 0
        matches = 0

        train_loss, train_acc = AverageMeter(), AverageMeter()

        for idx, train_batch in enumerate(train_loader):
            inputs, labels = train_batch
            if args.dataset == 'MaskDataset' or args.dataset == 'MaskOldDataset':
                labels = labels.argmax(dim=-1)
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            loss = criterion(outs, labels)

            loss.backward()
            optimizer.step()

            #loss_value += loss.item()
            #matches += (preds == labels).sum().item()
            acc = (preds == labels).sum().item() / len(labels)

            train_loss.update(loss.item(), len(labels))
            train_acc.update(acc, len(labels))

            if (idx + 1) % args.log_interval == 0:
                #train_loss = loss_value / args.log_interval
                #train_acc = matches / args.batch_size / args.log_interval
                train_f1_acc = f1_score(preds.cpu().detach().type(torch.int),
                                        labels.cpu().detach().type(torch.int),
                                        average='macro')
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch + 1}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss.avg:.4f} || training accuracy {train_acc.avg:4.2%} || train_f1_acc {train_f1_acc:.4} || lr {current_lr}"
                )
                logger.add_scalar("Train/loss", train_loss.avg,
                                  epoch * len(train_loader) + idx)
                logger.add_scalar("Train/accuracy", train_acc.avg,
                                  epoch * len(train_loader) + idx)

                loss_value = 0
                matches = 0

        if scheduler is not None and not isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step()

        val_loss, val_acc = AverageMeter(), AverageMeter()
        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_labels_items = np.array([])
            val_preds_items = np.array([])
            figure = None
            for val_batch in val_loader:
                inputs, labels = val_batch
                if args.dataset == 'MaskDataset' or args.dataset == 'MaskOldDataset':
                    labels = labels.argmax(dim=-1)

                inputs = inputs.to(device)
                labels = labels.to(device)

                outs = model(inputs)
                preds = torch.argmax(outs, dim=-1)

                #loss_item = criterion(outs, labels).item()
                #acc_item = (labels == preds).sum().item()
                #val_loss_items.append(loss_item)
                #val_acc_items.append(acc_item)

                loss = criterion(outs, labels)
                acc = (preds == labels).sum().item() / len(labels)

                val_loss.update(loss.item(), len(labels))
                val_acc.update(acc, len(labels))

                val_labels_items = np.concatenate(
                    [val_labels_items, labels.cpu().numpy()])
                val_preds_items = np.concatenate(
                    [val_preds_items, preds.cpu().numpy()])

                if figure is None:
                    if epoch % 2:
                        images, labels, preds = get_all_datas(
                            model, device, val_loader)
                        figure = log_confusion_matrix(
                            labels.cpu().numpy(),
                            np.argmax(preds.cpu().numpy(), axis=1),
                            args.num_classes)
                        # figure2 = plots_result(images.cpu().numpy()[:36], labels.cpu().numpy()[:36], preds.cpu().numpy()[:36], args.num_classes, title="plots_result")
                    else:
                        inputs_np = torch.clone(inputs).detach().cpu().permute(
                            0, 2, 3, 1).numpy()
                        inputs_np = val_dataset.denormalize_image(
                            inputs_np, MEAN, STD)
                        figure = grid_image(inputs_np, labels, preds, 9, False)

            # val_loss = np.sum(val_loss_items) / len(val_loader)
            # val_acc = np.sum(val_acc_items) / len(val_set)
            val_f1_acc = f1_score(val_labels_items.astype(int),
                                  val_preds_items.astype(int),
                                  average='macro')

            best_val_acc = max(best_val_acc, val_acc.avg)
            # best_val_loss = min(best_val_loss, val_loss)
            if val_loss.avg < best_val_loss:
                print(
                    f"New best model for val loss : {val_loss.avg:.4f}! saving the best model..."
                )
                torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
                best_val_loss = val_loss.avg
            torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
            print(
                f"[Val] acc : {val_acc.avg:4.2%}, loss : {val_loss.avg:.4f} || val_f1_acc : {val_f1_acc:.4} || "
                f"best acc : {best_val_acc:4.2%}, best loss : {best_val_loss:.4f}"
            )
            logger.add_scalar("Val/loss", val_loss.avg, epoch)
            logger.add_scalar("Val/accuracy", val_acc.avg, epoch)
            logger.add_figure("results", figure, epoch)
            # logger.add_figure("results1", figure2, epoch)

            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(val_loss.avg)

            early_stopping(val_loss.avg, model)

            if early_stopping.early_stop:
                print('Early stopping...')
                break

            print()
Example No. 15
0
    def forward(self, x=torch.rand(1, 3, 640, 640), requires_grad=True):
        if self.device:
            x = Variable(x, requires_grad=requires_grad).cuda()
        else:
            x = Variable(x, requires_grad=requires_grad)

        ### indexing model:
        x_ind = F.interpolate(x, size=160)

        x_ind = self.down0_conv(x_ind)
        x_ind = self.down0_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind, index = self.mp2d(x_ind)

        x_ind = self.down1_conv(x_ind)
        x_ind = self.down1_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind = self.down2_conv(x_ind)
        x_ind = self.down2_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind = self.down2_1_conv(x_ind)
        x_ind = self.down2_1_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind = self.down2_2_conv(x_ind)
        x_ind = self.down2_2_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind = self.down4_conv(x_ind)
        x_ind = self.down4_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind = self.Unmp2d(x_ind, index)

        x_ind = self.down3_conv(x_ind)
        x_ind = self.down3_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind = self.down5_conv(x_ind)
        x_ind = self.down5_conv_bn(x_ind)
        x_ind = F.leaky_relu(x_ind)

        x_ind = self.upSample(x_ind)
        x_ind = F.leaky_relu(x_ind)

        indexing_tensor = x_ind.narrow(1, 0,
                                       self.output_channels)  # [:,:49,:,:]
        indexing_tensor = torch.exp(F.log_softmax(indexing_tensor, 1))
        indexing_pred, indexing_argmax = indexing_tensor.max(dim=1)

        ## skeleton model:

        x1 = self.down0_convSkel(x)
        x2 = self.down0_conv_bnSkel(x1)
        x3 = F.relu(self.down1_convSkel(x2))
        x4 = self.down1_conv_bnSkel(x3)
        x5 = F.relu(x4)
        x6 = self.down2_convSkel(x5)
        x7 = F.relu(x6)

        x8_1, index = self.mp2dSkel(x7)
        x8_2 = self.Unmp2dSkel(x8_1, index)
        x9_1, index = self.mp2dSkel2(x7)
        x9_2 = self.Unmp2dSkel2(x9_1, index)
        # Replace alternating horizontal bands of the first unpooled map with the
        # corresponding bands of the second one: rows 0-10, then 20-row bands every
        # 40 rows starting at row 30, with the last band clipped to row 640.
        x8_2[0, 0, 0:10, :] = x9_2[0, 0, 0:10, :]
        for start in range(30, 631, 40):
            end = min(start + 20, 640)
            x8_2[0, 0, start:end, :] = x9_2[0, 0, start:end, :]

        skeleton_pred = x7
        skeleton_final = x8_2
        final_res = torch.clone(indexing_argmax)
        final_res[0, skeleton_final[0, 0, :, :] == 0] = 0

        return final_res, skeleton_final, skeleton_pred, indexing_tensor, indexing_argmax, indexing_pred
Example No. 16
0
    def from_torch(attention: TorchBertAttention,
                   layer_norm: Optional[TorchLayerNorm] = None,
                   is_trans_weight: bool = False):
        """
        Load an attention module from a HuggingFace BERT attention model.
        """
        ln_params = {}
        if layer_norm is not None:
            ln_params = {k: v for k, v in layer_norm.named_parameters()}
        params = {k: v for k, v in attention.named_parameters()}
        with torch.no_grad():
            if is_trans_weight:
                # merge self.query.weight, self.key.weight and self.value.weight together as qkv.weight
                qkv_weight = torch.cat(
                    (params['self.query.weight'], params['self.key.weight'],
                     params['self.value.weight']), 0)
                output_weight = params['output.dense.weight']
                k_w = params['self.key.weight']
                v_w = params['self.value.weight']
                q_w = params['self.query.weight']
            else:
                # merge self.query.weight, self.key.weight and self.value.weight together as qkv.weight
                qkv_weight = torch.clone(
                    torch.t(
                        torch.cat((params['self.query.weight'],
                                   params['self.key.weight'],
                                   params['self.value.weight']),
                                  0).contiguous()).contiguous())
                output_weight = torch.clone(
                    torch.t(params['output.dense.weight']).contiguous())
                k_w = torch.clone(
                    torch.t(params['self.key.weight']).contiguous())
                v_w = torch.clone(
                    torch.t(params['self.value.weight']).contiguous())
                q_w = torch.clone(
                    torch.t(params['self.query.weight']).contiguous())

            qkv_bias = torch.cat(
                (params['self.query.bias'], params['self.key.bias'],
                 params['self.value.bias']), 0)

            if layer_norm is not None:
                att = MultiHeadedAttention(
                    convert2tt_tensor(k_w),
                    convert2tt_tensor(params['self.key.bias']),
                    convert2tt_tensor(v_w),
                    convert2tt_tensor(params['self.value.bias']),
                    convert2tt_tensor(q_w),
                    convert2tt_tensor(params['self.query.bias']),
                    convert2tt_tensor(output_weight),
                    convert2tt_tensor(params['output.dense.bias']),
                    convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias),
                    convert2tt_tensor(params['output.LayerNorm.weight']),
                    convert2tt_tensor(params['output.LayerNorm.bias']),
                    convert2tt_tensor(ln_params['weight']),
                    convert2tt_tensor(ln_params['bias']),
                    attention.self.num_attention_heads)
            else:
                att = MultiHeadedAttention(
                    convert2tt_tensor(k_w),
                    convert2tt_tensor(params['self.key.bias']),
                    convert2tt_tensor(v_w),
                    convert2tt_tensor(params['self.value.bias']),
                    convert2tt_tensor(q_w),
                    convert2tt_tensor(params['self.query.bias']),
                    convert2tt_tensor(output_weight),
                    convert2tt_tensor(params['output.dense.bias']),
                    convert2tt_tensor(qkv_weight), convert2tt_tensor(qkv_bias),
                    convert2tt_tensor(params['output.LayerNorm.weight']),
                    convert2tt_tensor(params['output.LayerNorm.bias']),
                    attention.self.num_attention_heads)
            return att
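
A hedged usage sketch for the loader above: it consumes the parameter names exposed by a stock HuggingFace `BertAttention` ('self.query.weight', 'output.dense.weight', 'output.LayerNorm.weight', ...), merges Q/K/V into a fused `qkv_weight`, and builds the accelerated attention object. The wrapper class owning `from_torch` is assumed to be the `MultiHeadedAttention` used in the constructor above.

import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertAttention  # import path for transformers >= 4.x

cfg = BertConfig()                      # default BERT-base sizes
torch_attn = BertAttention(cfg).eval()

# These are exactly the keys the dictionary lookups above rely on.
print(sorted(name for name, _ in torch_attn.named_parameters()))

# Hypothetical call site (the class owning `from_torch` is an assumption):
# fast_attn = MultiHeadedAttention.from_torch(torch_attn)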
Example No. 17
0
    def train(self, G, deformator, shift_predictor):
        G.cuda().eval()
        deformator.cuda().train()
        shift_predictor.cuda().train()

        deformator_opt = torch.optim.Adam(deformator.parameters(), lr=self.p.deformator_lr) \
            if deformator.type not in [DeformatorType.ID, DeformatorType.RANDOM] else None
        shift_predictor_opt = torch.optim.Adam(shift_predictor.parameters(),
                                               lr=self.p.shift_predictor_lr)

        avgs = MeanTracker('percent'), MeanTracker('loss'), MeanTracker('direction_loss'),\
               MeanTracker('shift_loss'), MeanTracker('deformator_loss')
        avg_correct_percent, avg_loss, avg_label_loss, avg_shift_loss, avg_deformator_loss = avgs

        recovered_step = self.start_from_checkpoint(deformator,
                                                    shift_predictor)
        for step in range(recovered_step, self.p.n_steps, 1):
            G.zero_grad()
            deformator.zero_grad()
            shift_predictor.zero_grad()

            z = make_noise(self.p.batch_size, G.dim_z).cuda()
            z_orig = torch.clone(z)
            target_indices, shifts, z_shift = self.make_shifts(G.dim_z)

            # Deformation

            if self.p.global_deformation:
                z_shifted = deformator(z + z_shift)
                z = deformator(z)
            else:
                z_shifted = z + deformator(z_shift)
            imgs = G(z)
            imgs_shifted = G(z_shifted)

            logits, shift_prediction = shift_predictor(imgs, imgs_shifted)
            logit_loss = self.p.label_weight * self.cross_entropy(
                logits, target_indices)
            shift_loss = self.p.shift_weight * torch.mean(
                torch.abs(shift_prediction - shifts))

            # Loss

            # deformator penalty
            if self.p.deformation_loss == DeformatorLoss.STAT:
                z_std, z_mean = normal_projection_stat(z)
                z_loss = self.p.z_mean_weight * torch.abs(z_mean) + \
                    self.p.z_std_weight * torch.abs(1.0 - z_std)

            elif self.p.deformation_loss == DeformatorLoss.L2:
                z_loss = self.p.deformation_loss_weight * torch.mean(
                    torch.norm(z, dim=1))
                if z_loss < self.p.z_norm_loss_low_bound * torch.mean(
                        torch.norm(z_orig, dim=1)):
                    z_loss = torch.tensor([0.0], device='cuda')

            elif self.p.deformation_loss == DeformatorLoss.RELATIVE:
                deformation_norm = torch.norm(z - z_shifted, dim=1)
                z_loss = self.p.deformation_loss_weight * torch.mean(
                    torch.abs(deformation_norm - shifts))

            else:
                z_loss = torch.tensor([0.0], device='cuda')

            # total loss
            loss = logit_loss + shift_loss + z_loss
            loss.backward()

            if deformator_opt is not None:
                deformator_opt.step()
            shift_predictor_opt.step()

            # update statistics trackers
            avg_correct_percent.add(
                torch.mean((torch.argmax(logits, dim=1) == target_indices).to(
                    torch.float32)).detach())
            avg_loss.add(loss.item())
            avg_label_loss.add(logit_loss.item())
            avg_shift_loss.add(shift_loss.item())
            avg_deformator_loss.add(z_loss.item())

            self.log(G, deformator, shift_predictor, step, avgs)
Example No. 18
0
def cubo2rod(cu):
    """Cubochoric vector to Rodrigues-Frank vector.
        Quaternion returned in form s, <x,y,z>
        where s is real component, and <x,y,z>
        is imaginary vector component
    """
    """
        Step 1: Cubochoric vector to homochoric vector.

        References
        ----------
        D. Roşca et al., Modelling and Simulation in Materials Science and Engineering 22:075013, 2014
        https://doi.org/10.1088/0965-0393/22/7/075013

        """

    # get pyramid and scale by grid parameter ratio
    XYZ = torch.gather(cu, -1, _get_tensor_pyramid_order(cu, 'forward')) * sc
    order = torch.le(torch.abs(XYZ[..., 1:2]), torch.abs(XYZ[..., 0:1]))
    q = math.pi / 12.0 * torch.where(order, XYZ[..., 1:2], XYZ[..., 0:1]) \
        / torch.where(order, XYZ[..., 0:1], XYZ[..., 1:2])
    c = torch.cos(q)
    s = torch.sin(q)
    q = R1 * 2.0 ** 0.25 / beta / torch.sqrt(math.sqrt(2.0) - c) \
        * torch.where(order, XYZ[..., 0:1], XYZ[..., 1:2])

    T = torch.cat(((math.sqrt(2.0) * c - 1.0), math.sqrt(2.0) * s), dim=-1) * q

    # transform to sphere grid (inverse Lambert)
    c = torch.sum(T**2, -1, keepdim=True)
    s = c * math.pi / 24.0 / XYZ[..., 2:3]**2
    c = c * math.sqrt(math.pi / 24.0) / XYZ[..., 2:3]
    q = torch.sqrt(1.0 - s)

    ho = torch.where(
        torch.isclose(torch.sum(torch.abs(XYZ[..., 0:2]), -1, keepdim=True),
                      _precision_check(0.0, XYZ.dtype),
                      rtol=0.0,
                      atol=1.0e-16),
        torch.cat((torch.zeros_like(
            XYZ[..., 0:2]), math.sqrt(6.0 / math.pi) * XYZ[..., 2:3]),
                  dim=-1),
        torch.cat((torch.where(order, T[..., 0:1], T[..., 1:2]) * q,
                   torch.where(order, T[..., 1:2], T[..., 0:1]) * q,
                   math.sqrt(6.0 / math.pi) * XYZ[..., 2:3] - c),
                  dim=-1))

    ho[torch.isclose(torch.sum(torch.abs(cu), -1),
                     _precision_check(0.0, cu.dtype),
                     rtol=0.0,
                     atol=1.0e-16)] = 0.0  # the zero cubochoric vector maps to the zero homochoric vector
    ho = torch.gather(ho, -1, _get_tensor_pyramid_order(cu, 'backward'))

    # return ho # here for homochoric
    """Step 2: Homochoric vector to axis angle pair."""
    tfit = [
        +1.0000000000018852, -0.5000000002194847, -0.024999992127593126,
        -0.003928701544781374, -0.0008152701535450438, -0.0002009500426119712,
        -0.00002397986776071756, -0.00008202868926605841,
        +0.00012448715042090092, -0.0001749114214822577,
        +0.0001703481934140054, -0.00012062065004116828,
        +0.000059719705868660826, -0.00001980756723965647,
        +0.000003953714684212874, -0.00000036555001439719544
    ]
    hmag_squared = torch.sum(ho**2., -1, keepdim=True)

    hm = torch.clone(
        hmag_squared)  # running power of hmag_squared for the polynomial fit; add .detach() for a decoupled copy

    s = tfit[0] + tfit[1] * hmag_squared
    for i in range(2, 16):
        hm *= hmag_squared
        s += tfit[i] * hm

    # with np.errstate(invalid='ignore'):
    ax = torch.where(
        torch.lt(torch.abs(hmag_squared),
                 torch.tensor(1.e-8)).expand(ho.shape[:-1] + (4, )),
        _precision_check([0.0, 0.0, 1.0, 0.0], ho.dtype, ho.device),
        torch.cat((ho / torch.sqrt(hmag_squared),
                   2.0 * torch.arccos(torch.clip(s, -1.0, 1.0))),
                  dim=-1))

    # return ax # here for axis angle pair
    """Step 3: Axis angle pair to Rodrigues-Frank vector."""
    ro = torch.cat((ax[..., :3],
                    torch.where(
                        torch.isclose(ax[..., 3:4],
                                      _precision_check(math.pi, ax.dtype),
                                      atol=1.e-15,
                                      rtol=.0),
                        _precision_check(float('inf'), ax.dtype, ax.device),
                        torch.tan(ax[..., 3:4] * 0.5))),
                   dim=-1)
    ro[torch.lt(torch.abs(ax[..., 3]),
                1.e-6)] = _precision_check([.0, .0, P, .0], ax.dtype,
                                           ax.device)

    return ro
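
A quick, hedged sanity check of the three-step conversion above, assuming the module-level constants (`sc`, `beta`, `R1`, `P`) and helpers (`_precision_check`, `_get_tensor_pyramid_order`) are defined as in the surrounding module: the zero cubochoric vector is the identity rotation, and the small-angle branch at the end maps it to the Rodrigues-Frank vector [0, 0, P, 0].

cu = torch.zeros(5, 3, dtype=torch.float64)   # a batch of identity rotations
ro = cubo2rod(cu)
print(ro)   # each row should read [0., 0., P, 0.] via the |angle| < 1e-6 branch above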
Example No. 19
0
def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future:
    """
    This DDP communication hook implements the original PowerSGD gradient compression
    algorithm described in https://arxiv.org/abs/1905.13727.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:
    1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
    high-rank tensors and vector-like rank-1 tensors (for biases).
    2) Handles rank-1 tensors by allreducing them without compression:
        2.1) Allocate contiguous memory for those rank-1 tensors,
        and allreduces all the rank-1 tensors as a batch, without compression;
        2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
    3) Handles high-rank tensors by PowerSGD compression:
        3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
        such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
        3.2) Computes each P in Ps, which is equal to MQ;
        3.3) Allreduces Ps as a batch;
        3.4) Orthogonalizes each P in Ps;
        3.5) Computes each Q in Qs, which is approximately equal to M^TP;
        3.6) Allreduces Qs as a batch;
        3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.

    Note that this communication hook enforces vanilla allreduce for the first `state.start_powerSGD_iter` iterations.
    This not only allows the user finer tuning over the tradeoff between speedup and accuracy,
    but also helps abstract away some complexity of DDP's internal optimization for future communication hook developers.

    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
    one left multiplication and one right multiplication.
    For warm-start, can take one such step at a time, and alternate between them.

    Args:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
            To tune the compression configs, see Note [Guidance to Tune `matrix_approximation_rank` And `start_powerSGD_iter`].
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode at this time,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]

    # Run vanilla allreduce in the first `start_powerSGD_iter` iterations.
    if state.iter < state.start_powerSGD_iter:
        state.maybe_increase_iter(bucket)
        return default._allreduce_fut(group_to_use, input_tensor)

    # Apply PowerSGD after `start_powerSGD_iter` iterations.
    device = input_tensor.device
    dtype = input_tensor.dtype

    # Incorporate the error from the previous state into the gradients.
    bucket_index = bucket.get_index()
    input_tensor_cp = None
    total_length = input_tensor.shape[0]
    if state.use_error_feedback:
        if bucket_index in state.error_dict:
            input_tensor.add_(state.error_dict[bucket_index])
        else:
            logging.info(
                "A zero tensor of length {} that represents local error is created."
                .format(total_length))
            state.error_dict[bucket_index] = torch.zeros(total_length,
                                                         device=device,
                                                         dtype=dtype)

        # Keep a copy of the input tensor,
        # so that we can compute the local error caused by compression later,
        # by comparing this copy and the input tensor updated after decompression.
        input_tensor_cp = torch.clone(input_tensor).detach()

    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
    tensors = [
        input_tensor[offset:offset + length].view(sizes)
        for offset, length, sizes in zip(bucket.get_offsets(
        ), bucket.get_lengths(), bucket.get_sizes_list())
    ]

    # Step I: Handle rank-1 tensors.
    # Allocate contiguous memory for rank-1 tensors to allreduce them without compression efficiently.
    rank1_tensors = [tensor for tensor in tensors if tensor.ndimension() <= 1]
    rank1_tensors_memory = (torch.cat([
        tensor.view(-1) for tensor in rank1_tensors
    ]) if rank1_tensors else torch.tensor([], device=device, dtype=dtype))

    # Step II: Handle high-rank tensors.
    # Allocate contiguous memory for Ps and Qs to allreduce compressed high-rank tensors efficiently.
    high_rank_tensors = [
        tensor.view(tensor.shape[0], -1) for tensor in tensors
        if tensor.ndimension() > 1
    ]
    total_Ps_size = 0
    total_Qs_size = 0
    for tensor in high_rank_tensors:
        n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        total_Ps_size += n * matrix_approximation_rank
        total_Qs_size += m * matrix_approximation_rank
    # If warm-start is enabled, reuse Ps and Qs from the previous iteration if possible.
    # The memory spaces of Ps and Qs need to be allocated in the first iteration when PowerSGD is applied.
    need_randomize_qs = False
    if not state.warm_start or bucket_index not in state.p_memory_dict:
        need_randomize_qs = True
        # If warm-start is disabled, low-rank tensors will be initialized at every step.
        # Only log this when warm-start is enabled, to avoid spamming.
        if state.warm_start:
            logging.info(
                "Allocating contiguous memory of length {} for Ps, and of length {} for Qs, respectively."
                .format(total_Ps_size, total_Qs_size))
        state.p_memory_dict[bucket_index] = torch.empty(total_Ps_size,
                                                        device=device,
                                                        dtype=dtype)
        state.q_memory_dict[bucket_index] = torch.empty(total_Qs_size,
                                                        device=device,
                                                        dtype=dtype)

    # Create Ps and Qs that point to the allocated memory.
    ps = []
    qs = []
    p_idx = 0
    q_idx = 0
    for tensor in high_rank_tensors:
        n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        ps.append(
            state.p_memory_dict[bucket_index][p_idx:p_idx + n *
                                              matrix_approximation_rank].view(
                                                  n,
                                                  matrix_approximation_rank))
        qs.append(
            state.q_memory_dict[bucket_index][q_idx:q_idx + m *
                                              matrix_approximation_rank].view(
                                                  m,
                                                  matrix_approximation_rank))
        p_idx += n * matrix_approximation_rank
        q_idx += m * matrix_approximation_rank

    # If warm-start is enabled, reuse Qs from the previous iteration if possible and skip filling random values.
    # The exception is the first iteration when PowerSGD is applied.
    if not need_randomize_qs:
        for q in qs:
            _orthogonalize(q)
    else:
        with torch.random.fork_rng(devices=[]):
            # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
            # The seed makes sure that the initial random values are the same across all the DDP replicas.
            # This seed should differ at every step.
            # Since it is very slow to fork RNG state across all the CUDA devices,
            # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q).
            torch.manual_seed(state.rng.randint(1_000_000_000))
            for q in qs:
                q.copy_(torch.randn(
                    *q.shape,
                    device="cpu",
                    dtype=dtype,
                ))
                _orthogonalize(q)

    # Compute Ps.
    for tensor, q, p in zip(high_rank_tensors, qs, ps):
        torch.matmul(tensor, q, out=p)

    # This allreduce is only applied to rank-1 tensors,
    # so it should have been kicked off before the above computation on the high-rank tensors to hide more communication costs.
    # However, this somehow requires a separate future chain at this time.
    allreduce_contiguous_rank1_tensors_fut = dist.all_reduce(
        rank1_tensors_memory, group=group_to_use, async_op=True).get_future()

    def unpack_rank1_tensors_and_allreduce_ps(fut):
        rank1_tensors_memory = fut.value()[0].div_(world_size)
        idx = 0
        for tensor in rank1_tensors:
            tensor.copy_(rank1_tensors_memory[idx:idx + tensor.shape[0]])
            idx += tensor.shape[0]

        # Since these Ps will be orthogonalized later, no need to divide them by world size.
        return [
            dist.all_reduce(state.p_memory_dict[bucket_index],
                            group=group_to_use,
                            async_op=True).get_future().wait()[0]
        ]

    def compute_qs(fut):
        state.p_memory_dict[bucket_index] = fut.value()[0]
        for p in ps:
            _orthogonalize(p)

        # Compute Qs.
        for tensor, p, q in zip(high_rank_tensors, ps, qs):
            torch.matmul(tensor.t(), p, out=q)

        # Allreduce Qs.
        return [
            dist.all_reduce(state.q_memory_dict[bucket_index],
                            group=group_to_use,
                            async_op=True).get_future().wait()[0]
        ]

    def decompress(fut):
        state.q_memory_dict[bucket_index] = fut.value()[0].div_(world_size)

        for p, q, tensor in zip(ps, qs, high_rank_tensors):
            torch.matmul(p, q.t(), out=tensor)
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)

        if state.use_error_feedback:
            # Memorize the local errors.
            state.error_dict[bucket_index] = input_tensor_cp - input_tensor
        if not state.warm_start:
            state.p_memory_dict.clear()
            state.q_memory_dict.clear()

        state.maybe_increase_iter(bucket)

        return [input_tensor]

    return (allreduce_contiguous_rank1_tensors_fut.then(
        unpack_rank1_tensors_and_allreduce_ps).then(compute_qs).then(
            decompress))
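
The compression in steps 3.1-3.7 of the docstring is easier to see on a single matrix, outside of DDP. The sketch below is a hedged, standalone illustration: `torch.linalg.qr` stands in for the hook's internal `_orthogonalize`, and the two allreduce rounds are omitted since only one worker is involved.

import torch

def powersgd_round(M: torch.Tensor, rank: int = 1) -> torch.Tensor:
    """One PowerSGD compress/decompress round for a single matrix (no communication)."""
    n, m = M.shape
    rank = min(n, m, rank)
    Q = torch.randn(m, rank)
    Q, _ = torch.linalg.qr(Q)      # 3.1: random, orthogonalized Q
    P = M @ Q                      # 3.2: P = MQ      (3.3 would allreduce Ps)
    P, _ = torch.linalg.qr(P)      # 3.4: orthogonalize P
    Q = M.t() @ P                  # 3.5: Q ~= M^T P  (3.6 would allreduce Qs)
    return P @ Q.t()               # 3.7: M ~= P Q^T

M = torch.randn(64, 32)
M_hat = powersgd_round(M, rank=4)
print(torch.linalg.matrix_rank(M_hat))   # at most 4
print((M - M_hat).norm() / M.norm())     # relative compression error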
Example No. 20
0
def evaluate(args, model, criterions, dataloader):
    model.eval()
    epoch_loss = 0
    n_class = 12
    example_images = []
    with torch.no_grad():
        hist = np.zeros((n_class, n_class))
        miou_images = []
        for images, masks, _ in dataloader:

            images = torch.stack(images)  # (batch, channel, height, width)
            masks = torch.stack(
                masks).long()  # (batch, height, width)

            images, masks = images.to(args.device), masks.to(args.device)

            outputs = model(images)
            flag = criterions[0]
            if flag == "+":
                loss = criterions[1](outputs, masks) + criterions[2](outputs,
                                                                     masks)
            elif flag == "-":
                loss = criterions[1](outputs, masks) - criterions[2](outputs,
                                                                     masks)
            else:
                loss = criterions[1](outputs, masks)
            epoch_loss += loss

            inputs_np = torch.clone(images).detach().cpu().permute(0, 2, 3,
                                                                   1).numpy()
            inputs_np = denormalize_image(inputs_np,
                                          mean=(0.4611, 0.4403, 0.4193),
                                          std=(0.2107, 0.2074, 0.2157))

            example_images.append(
                wb_mask(
                    inputs_np[0],
                    pred_mask=outputs.argmax(1)[0].detach().cpu().numpy(),
                    true_mask=masks[0].detach().cpu().numpy(),
                ))

            outputs = torch.argmax(outputs.squeeze(),
                                   dim=1).detach().cpu().numpy()

            hist = add_hist(hist,
                            masks.detach().cpu().numpy(),
                            outputs,
                            n_class=n_class)

            # store the per-image mIoU
            miou_list = get_miou(masks.detach().cpu().numpy(),
                                 outputs,
                                 n_class=n_class)
            miou_images.extend(miou_list)

        # metrics
        acc, acc_cls, miou, fwavacc = label_accuracy_score(hist)

        # leaderboard mIoU
        lb_miou = np.nanmean(miou_images)

        print(f"acc:{acc:.4f}, acc_cls:{acc_cls:.4f}, fwavacc:{fwavacc:.4f}")

        # log the confusion-matrix histogram to wandb
        summa = hist.sum(1).reshape(-1, 1)
        percent = hist / summa
        plt.figure(figsize=(10, 10))
        sns.heatmap(percent, annot=True, fmt=".2%", annot_kws={"size": 8})
        wandb.log({"percent_hist": wandb.Image(plt)}, commit=False)

    return (epoch_loss / len(dataloader)), lb_miou, miou, example_images
Example No. 21
0
    def sample(self):
        if self.config.ncsn.sampling.ckpt_id is None:
            ncsn_states = torch.load(os.path.join(
                'scones', self.config.ncsn.sampling.log_path,
                'checkpoint.pth'),
                                     map_location=self.config.device)
        else:
            ncsn_states = torch.load(os.path.join(
                'scones', self.config.ncsn.sampling.log_path,
                f'checkpoint_{self.config.ncsn.sampling.ckpt_id}.pth'),
                                     map_location=self.config.device)

        score = get_scorenet(self.config)
        score = torch.nn.DataParallel(score)

        sigmas_th = get_sigmas(self.config.ncsn)
        sigmas = sigmas_th.cpu().numpy()

        if ("module.sigmas" in ncsn_states[0].keys()):
            ncsn_states[0]["module.sigmas"] = sigmas_th

        score.load_state_dict(ncsn_states[0], strict=True)
        score.eval()

        baryproj_data_init = (hasattr(self.config, "baryproj")
                              and self.config.ncsn.sampling.data_init)

        if (baryproj_data_init):
            if (self.config.baryproj.ckpt_id is None):
                bproj_states = torch.load(os.path.join(
                    'scones', self.config.baryproj.log_path, 'checkpoint.pth'),
                                          map_location=self.config.device)
            else:
                bproj_states = torch.load(os.path.join(
                    'scones', self.config.baryproj.log_path,
                    f'checkpoint_{self.config.baryproj.ckpt_id}.pth'),
                                          map_location=self.config.device)

            bproj = get_bary(self.config)
            bproj.load_state_dict(bproj_states[0])
            bproj = torch.nn.DataParallel(bproj)
            bproj.eval()

        if self.config.compatibility.ckpt_id is None:
            cpat_states = torch.load(os.path.join(
                'scones', self.config.compatibility.log_path,
                'checkpoint.pth'),
                                     map_location=self.config.device)
        else:
            cpat_states = torch.load(os.path.join(
                'scones', self.config.compatibility.log_path,
                f'checkpoint_{self.config.compatibility.ckpt_id}.pth'),
                                     map_location=self.config.device)

        cpat = get_compatibility(self.config)
        cpat.load_state_dict(cpat_states[0])

        if self.config.ncsn.model.ema:
            ema_helper = EMAHelper(mu=self.config.ncsn.model.ema_rate)
            ema_helper.register(score)
            ema_helper.load_state_dict(ncsn_states[-1])
            ema_helper.ema(score)

        source_dataset, _ = get_dataset(self.args, self.config.source)
        dataloader = DataLoader(
            source_dataset,
            batch_size=self.config.ncsn.sampling.sources_per_batch,
            shuffle=True,
            num_workers=self.config.source.data.num_workers)
        data_iter = iter(dataloader)

        (Xs, labels) = next(data_iter)
        Xs_global = torch.cat([Xs] *
                              self.config.ncsn.sampling.samples_per_source,
                              dim=0).to(self.config.device)
        Xs_global = data_transform(self.config.source, Xs_global)

        if (hasattr(self.config.ncsn.sampling, "n_sigmas_skip")):
            n_sigmas_skip = self.config.ncsn.sampling.n_sigmas_skip
        else:
            n_sigmas_skip = 0

        if not self.config.ncsn.sampling.fid:
            if self.config.ncsn.sampling.inpainting:
                ''' NCSN INPAINTING CODE. EITHER PATCH THIS FOR SCONES OR REMOVE IT. 
                
                data_iter = iter(dataloader)
                refer_images, _ = next(data_iter)
                refer_images = refer_images.to(self.config.device)
                width = int(np.sqrt(self.config.sampling.batch_size))
                init_samples = torch.rand(width, width, self.config.data.channels,
                                          self.config.data.image_size,
                                          self.config.data.image_size,
                                          device=self.config.device)
                init_samples = data_transform(self.config, init_samples)
                all_samples = anneal_Langevin_dynamics_inpainting(init_samples, refer_images[:width, ...], score,
                                                                  sigmas,
                                                                  self.config.data.image_size,
                                                                  self.config.sampling.n_steps_each,
                                                                  self.config.sampling.step_lr)

                torch.save(refer_images[:width, ...], os.path.join(self.args.image_folder, 'refer_image.pth'))
                refer_images = refer_images[:width, None, ...].expand(-1, width, -1, -1, -1).reshape(-1,
                                                                                                     *refer_images.shape[
                                                                                                      1:])
                save_image(refer_images, os.path.join(self.args.image_folder, 'refer_image.png'), nrow=width)

                if not self.config.sampling.final_only:
                    for i, sample in enumerate(tqdm.tqdm(all_samples)):
                        sample = sample.view(self.config.sampling.batch_size, self.config.data.channels,
                                             self.config.data.image_size,
                                             self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, int(np.sqrt(self.config.sampling.batch_size)))
                        save_image(image_grid, os.path.join(self.args.image_folder, 'image_grid_{}.png'.format(i)))
                        torch.save(sample, os.path.join(self.args.image_folder, 'completion_{}.pth'.format(i)))
                else:
                    sample = all_samples[-1].view(self.config.sampling.batch_size, self.config.data.channels,
                                                  self.config.data.image_size,
                                                  self.config.data.image_size)

                    sample = inverse_data_transform(self.config, sample)

                    image_grid = make_grid(sample, int(np.sqrt(self.config.sampling.batch_size)))
                    save_image(image_grid, os.path.join(self.args.image_folder,
                                                        'image_grid_{}.png'.format(self.config.ncsn.sampling.ckpt_id)))
                    torch.save(sample, os.path.join(self.args.image_folder,
                                                    'completion_{}.pth'.format(self.config.sampling.ckpt_id)))
                '''

                raise NotImplementedError(
                    "Inpainting with SCONES is not currently implemented.")
            elif self.config.ncsn.sampling.interpolation:
                ''' NCSN INTERPOLATION CODE. EITHER PATCH THIS FOR SCONES OR REMOVE IT. 
                
                if self.config.sampling.data_init:
                    data_iter = iter(dataloader)
                    samples, _ = next(data_iter)
                    samples = samples.to(self.config.device)
                    samples = data_transform(self.config, samples)
                    init_samples = samples + sigmas_th[0] * torch.randn_like(samples)

                else:
                    init_samples = torch.rand(self.config.sampling.batch_size, self.config.data.channels,
                                              self.config.data.image_size, self.config.data.image_size,
                                              device=self.config.device)
                    init_samples = data_transform(self.config, init_samples)

                all_samples = anneal_Langevin_dynamics_interpolation(init_samples, score, sigmas,
                                                                     self.config.sampling.n_interpolations,
                                                                     self.config.sampling.n_steps_each,
                                                                     self.config.sampling.step_lr, verbose=True,
                                                                     final_only=self.config.sampling.final_only)

                if not self.config.sampling.final_only:
                    for i, sample in tqdm.tqdm(enumerate(all_samples), total=len(all_samples),
                                               desc="saving image samples"):
                        sample = sample.view(sample.shape[0], self.config.data.channels,
                                             self.config.data.image_size,
                                             self.config.data.image_size)

                        sample = inverse_data_transform(self.config, sample)

                        image_grid = make_grid(sample, nrow=self.config.sampling.n_interpolations)
                        save_image(image_grid, os.path.join(self.args.image_folder, 'image_grid_{}.png'.format(i)))
                        torch.save(sample, os.path.join(self.args.image_folder, 'samples_{}.pth'.format(i)))
                else:
                    sample = all_samples[-1].view(all_samples[-1].shape[0], self.config.data.channels,
                                                  self.config.data.image_size,
                                                  self.config.data.image_size)

                    sample = inverse_data_transform(self.config, sample)

                    image_grid = make_grid(sample, self.config.sampling.n_interpolations)
                    save_image(image_grid, os.path.join(self.args.image_folder,
                                                        'image_grid_{}.png'.format(self.config.sampling.ckpt_id)))
                    torch.save(sample, os.path.join(self.args.image_folder,
                                                    'samples_{}.pth'.format(self.config.sampling.ckpt_id)))
                '''
                raise NotImplementedError(
                    "Interpolation with SCONES is not currently implemented.")
            else:
                if self.config.ncsn.sampling.data_init:
                    if (baryproj_data_init):
                        with torch.no_grad():
                            init_Xt = (bproj(Xs_global) +
                                       sigmas_th[n_sigmas_skip] *
                                       torch.randn_like(Xs_global)).detach()
                    else:
                        init_Xt = Xs_global + sigmas_th[
                            n_sigmas_skip] * torch.randn_like(Xs_global)

                    init_Xt.requires_grad = True
                    init_Xt = init_Xt.to(self.config.device)

                else:
                    init_Xt = torch.rand(
                        self.config.ncsn.sampling.sources_per_batch *
                        self.config.ncsn.sampling.samples_per_source,
                        self.config.target.data.channels,
                        self.config.target.data.image_size,
                        self.config.target.data.image_size,
                        device=self.config.device)
                    init_Xt = data_transform(self.config.target, init_Xt)
                    init_Xt.requires_grad = True
                    init_Xt = init_Xt.to(self.config.device)

                all_samples = anneal_Langevin_dynamics(
                    init_Xt,
                    Xs_global,
                    score,
                    cpat,
                    sigmas,
                    self.config.ncsn.sampling.n_steps_each,
                    self.config.ncsn.sampling.step_lr,
                    verbose=True,
                    final_only=self.config.ncsn.sampling.final_only,
                    denoise=self.config.ncsn.sampling.denoise,
                    n_sigmas_skip=n_sigmas_skip)

                all_samples = torch.stack(all_samples, dim=0)

                if not self.config.ncsn.sampling.final_only:
                    all_samples = all_samples.view(
                        (-1, self.config.ncsn.sampling.sources_per_batch,
                         self.config.ncsn.sampling.samples_per_source,
                         self.config.target.data.channels,
                         self.config.target.data.image_size,
                         self.config.target.data.image_size))
                    np.save(
                        os.path.join(self.args.image_folder,
                                     'all_samples.npy'),
                        all_samples.detach().cpu().numpy())

                sample = all_samples[-1].view(
                    self.config.ncsn.sampling.sources_per_batch *
                    self.config.ncsn.sampling.samples_per_source,
                    self.config.target.data.channels,
                    self.config.target.data.image_size,
                    self.config.target.data.image_size)

                sample = inverse_data_transform(self.config.target, sample)

                image_grid = make_grid(
                    sample, nrow=self.config.ncsn.sampling.sources_per_batch)
                save_image(
                    image_grid,
                    os.path.join(self.args.image_folder, 'sample_grid.png'))

                source_grid = make_grid(
                    Xs, nrow=self.config.ncsn.sampling.sources_per_batch)
                save_image(
                    source_grid,
                    os.path.join(self.args.image_folder, 'source_grid.png'))

                # bproj only exists when barycentric-projection data init is enabled
                if baryproj_data_init:
                    bproj_of_source = make_grid(
                        bproj(Xs),
                        nrow=self.config.ncsn.sampling.sources_per_batch)
                    save_image(
                        bproj_of_source,
                        os.path.join(self.args.image_folder,
                                     'bproj_sources.png'))

                np.save(os.path.join(self.args.image_folder, 'sources.npy'),
                        Xs.detach().cpu().numpy())
                np.save(
                    os.path.join(self.args.image_folder, 'source_labels.npy'),
                    labels.detach().cpu().numpy())
                if baryproj_data_init:
                    np.save(
                        os.path.join(self.args.image_folder, 'bproj.npy'),
                        bproj(Xs).detach().cpu().numpy())
                np.save(os.path.join(self.args.image_folder, 'samples.npy'),
                        sample.detach().cpu().numpy())

        else:
            batch_size = self.config.ncsn.sampling.sources_per_batch * self.config.ncsn.sampling.samples_per_source
            total_n_samples = self.config.ncsn.sampling.num_samples4fid
            n_rounds = total_n_samples // batch_size
            if self.config.ncsn.sampling.data_init:
                dataloader = DataLoader(
                    source_dataset,
                    batch_size=self.config.ncsn.sampling.sources_per_batch,
                    shuffle=True,
                    num_workers=self.config.source.data.num_workers)
                data_iter = iter(dataloader)

            img_id = 0
            for r in tqdm.tqdm(
                    range(n_rounds),
                    desc=
                    'Generating image samples for FID/inception score evaluation'
            ):
                if self.config.ncsn.sampling.data_init:
                    try:
                        init_samples, labels = next(data_iter)
                        init_samples = torch.cat(
                            [init_samples] *
                            self.config.ncsn.sampling.samples_per_source,
                            dim=0)
                        labels = torch.cat(
                            [labels] *
                            self.config.ncsn.sampling.samples_per_source,
                            dim=0)
                    except StopIteration:
                        data_iter = iter(dataloader)
                        init_samples, labels = next(data_iter)
                        init_samples = torch.cat(
                            [init_samples] *
                            self.config.ncsn.sampling.samples_per_source,
                            dim=0)
                        labels = torch.cat(
                            [labels] *
                            self.config.ncsn.sampling.samples_per_source,
                            dim=0)

                    init_samples = init_samples.to(self.config.device)
                    init_samples = data_transform(self.config.target,
                                                  init_samples)

                    if (baryproj_data_init):
                        with torch.no_grad():
                            bproj_samples = bproj(init_samples).detach()
                    else:
                        bproj_samples = torch.clone(init_samples).detach()

                    samples = bproj_samples + sigmas_th[
                        n_sigmas_skip] * torch.randn_like(bproj_samples)
                    samples.requires_grad = True
                    samples = samples.to(self.config.device)
                else:
                    samples = torch.rand(batch_size,
                                         self.config.target.data.channels,
                                         self.config.target.data.image_size,
                                         self.config.target.data.image_size,
                                         device=self.config.device)
                    init_samples = torch.clone(samples)
                    samples = data_transform(self.config.target, samples)
                    samples.requires_grad = True
                    samples = samples.to(self.config.device)

                all_samples = anneal_Langevin_dynamics(
                    samples,
                    Xs_global,
                    score,
                    cpat,
                    sigmas,
                    self.config.ncsn.sampling.n_steps_each,
                    self.config.ncsn.sampling.step_lr,
                    verbose=True,
                    final_only=self.config.ncsn.sampling.final_only,
                    denoise=self.config.ncsn.sampling.denoise,
                    n_sigmas_skip=n_sigmas_skip)

                samples = all_samples[-1]
                for img in samples:
                    img = inverse_data_transform(self.config.target, img)
                    save_image(
                        img,
                        os.path.join(self.args.image_folder,
                                     'image_{}.png'.format(img_id)))
                    img_id += 1

                if (self.args.save_labels):
                    save_path = os.path.join(self.args.image_folder, 'labels')
                    np.save(os.path.join(save_path, f'sources_{r}.npy'),
                            init_samples.detach().cpu().numpy())
                    np.save(os.path.join(save_path, f'source_labels_{r}.npy'),
                            labels.detach().cpu().numpy())
                    np.save(os.path.join(save_path, f"bproj_{r}.npy"),
                            bproj_samples.detach().cpu().numpy())
                    np.save(os.path.join(save_path, f"samples_{r}.npy"),
                            samples.detach().cpu().numpy())
Example No. 22
0
    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Performs an analog-aware single optimization step.

        If a group containing analog parameters is detected, the optimization
        step calls the related RPU controller. For regular parameter groups,
        the optimization step has the same behaviour as ``torch.optim.SGD``.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss, if ``closure`` has been passed as a parameter.
        """
        # pylint: disable=too-many-branches
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            learning_rate = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            # Use analog_tile object.
            if group.get('analog_tile'):
                analog_tile = group['analog_tile']

                # Update learning rate
                analog_tile.set_learning_rate(learning_rate)

                weights = next(param for param in group['params']
                               if getattr(param, 'is_weight', False))

                # Call `update` in the tile.
                if weights.use_indexed:
                    analog_tile.update_indexed(weights.input, weights.grad_output)
                else:
                    analog_tile.update(weights.input, weights.grad_output)

                # Apply post-update step operations (diffuse, decay, etc).
                analog_tile.post_update_step()
                continue

            for param in group['params']:
                if param.grad is None:
                    continue
                d_p = param.grad
                if weight_decay != 0:
                    d_p = d_p.add(param, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[param]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                param.add_(d_p, alpha=-group['lr'])

        return loss
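
For parameter groups without an `analog_tile`, the loop above is meant to reproduce plain `torch.optim.SGD` (weight decay, momentum, Nesterov), as the docstring states. A minimal, hedged cross-check against the reference optimizer, with no analog tiles involved:

import torch

w = torch.nn.Parameter(torch.randn(8))
ref = torch.optim.SGD([w], lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
for _ in range(3):
    ref.zero_grad()
    (w ** 2).sum().backward()
    ref.step()   # same d_p / momentum_buffer arithmetic as the per-parameter branch above
print(w)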
Example No. 23
0
import torch
Example No. 24
0
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): 
                A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for w in group['params']:
                if w.grad is None:
                    continue
                grad = w.grad.data

                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, '
                        'please consider SparseAdam instead')

                amsgrad = group['amsgrad']

                state = self.state[w]

                # state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(w.data)
                    # exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(w.data)
                    # moving average for the non-orthogonal projection scaling
                    state['exp_avg2'] = w.new(1).fill_(0)
                    if amsgrad:
                        # maintains max of all exp. moving avg.
                        # of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(w.data)

                exp_avg, exp_avg2, exp_avg_sq = \
                    state['exp_avg'], state['exp_avg2'], state['exp_avg_sq']

                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad.add_(w.data, alpha=group['weight_decay'])

                # if it's in the SGD phase, take an SGD update and continue
                if group['phase'] == 'SGD':
                    if 'momentum_buffer' not in state:
                        buf = state['momentum_buffer'] = torch.clone(
                            grad).detach()
                    else:
                        buf = state['momentum_buffer']
                        buf.mul_(beta1).add_(grad)
                        grad = buf

                    grad.mul_(1 - beta1)
                    if group['nesterov']:
                        grad.add_(buf, alpha=beta1)

                    w.data.add_(grad, alpha=-group['lr'])
                    continue

                # decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # maintains the maximum of all 2nd
                    # moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1**state['step']
                bias_correction2 = 1 - beta2**state['step']
                step_size = group['lr'] * \
                    (bias_correction2 ** 0.5) / bias_correction1

                p = -step_size * (exp_avg / denom)
                w.data.add_(p)

                p_view = p.view(-1)
                pg = p_view.dot(grad.view(-1))

                if pg != 0:
                    # the non-orthogonal scaling estimate
                    scaling = p_view.dot(p_view) / -pg
                    exp_avg2.mul_(beta2).add_(scaling, alpha=1 - beta2)

                    # bias corrected exponential average
                    corrected_exp_avg = exp_avg2 / bias_correction2

                    # checking criteria of switching to SGD training
                    if state['step'] > 1 and \
                            corrected_exp_avg.allclose(scaling, rtol=1e-6) and \
                            corrected_exp_avg > 0:
                        group['phase'] = 'SGD'
                        group['lr'] = corrected_exp_avg.item()
                        if group['verbose']:
                            print('Switching to SGD after '
                                  '{} steps with lr {:.5f} '
                                  'and momentum {:.5f}.'.format(
                                      state['step'], group['lr'], beta1))

        return loss
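
# A standalone sketch of the switching statistic computed above: the Adam step
# ``p`` is projected onto the gradient, and ``p.p / (-p.g)`` estimates the SGD
# learning rate that would produce a comparable step (toy tensors only).
import torch

g = torch.randn(10)                        # stand-in gradient
p = -0.001 * g / (g.abs().mean() + 1e-8)   # stand-in Adam-style step
pg = p.view(-1).dot(g.view(-1))
if pg != 0:
    scaling = p.view(-1).dot(p.view(-1)) / -pg   # non-orthogonal scaling estimate
    print(float(scaling))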
Ejemplo n.º 25
0
def train_manipulator(model, data_loaders, args):
    """Train an emotion EBM."""
    device = args.device
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    model, optimizer, _, start_epoch, is_trained = load_from_ckpnt(
        args.classifier_ckpnt, model, optimizer, scheduler=None
    )
    if is_trained:
        return model
    writer = SummaryWriter('runs/' + args.checkpoint.replace('.pt', ''))

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        print("Epoch: %d/%d" % (epoch + 1, args.epochs))
        kbar = pkbar.Kbar(target=len(data_loaders['train']), width=25)
        model.train()
        model.disable_batchnorm()
        model.zero_grad()
        # model.enable_grads()
        for step, ex in enumerate(data_loaders['train']):
            images, _, emotions, neg_images = ex
            # positive samples
            pos_samples = images.to(device)
            # prepare negative samples
            neg_samples, neg_masks = rand_mask(images.clone().to(device), device)
            # negative samples
            neg_ld_samples, neg_list = langevin_updates(
                model, torch.clone(neg_samples),
                args.langevin_steps, args.langevin_step_size,
                neg_masks
            )
            # Compute energy
            pos_out = model(pos_samples)
            neg_img_out = model(neg_images.to(device))
            neg_ld_out = model(neg_ld_samples.to(device))
            # Loss
            loss_reg = (pos_out**2 + neg_ld_out**2 + neg_img_out**2).mean()
            # loss_reg = (torch.abs(pos_out) + torch.abs(neg_ld_out) + torch.abs(neg_img_out)).mean()
            loss_ml = 2*pos_out.mean() - neg_ld_out.mean() - neg_img_out.mean()
            coeff = loss_ml.detach().clone() / loss_reg.detach().clone()
            loss = 0.5*loss_reg + loss_ml
            # if epoch == 0:
            #     loss = loss * 0.05
            '''
            loss = (
                pos_out**2 + neg_out**2 + neg_img_out**2 + neg_img_ld_out**2
                + 3*pos_out - neg_out - neg_img_out - neg_img_ld_out
            ).mean()
             '''
            # Step
            optimizer.zero_grad()
            loss.backward()
            clip_grad(model.parameters(), optimizer)
            optimizer.step()
            kbar.update(step, [("loss", loss)])
            # Log loss
            writer.add_scalar('energy/energy_pos', pos_out.mean().item(), epoch * len(data_loaders['train']) + step)
            writer.add_scalar('energy/energy_neg', neg_ld_out.mean().item(), epoch * len(data_loaders['train']) + step)
            writer.add_scalar('loss/loss_reg', loss_reg.item(), epoch * len(data_loaders['train']) + step)
            writer.add_scalar('loss/loss_ml', loss_ml.item(), epoch * len(data_loaders['train']) + step)
            writer.add_scalar('loss/loss_total', loss.item(), epoch * len(data_loaders['train']) + step)
            # Log image evolution
            if step % 50 != 0:
                continue
            writer.add_image(
                'random_image_sample',
                back2color(unnormalize_imagenet_rgb(pos_samples[0], device)),
                epoch * len(data_loaders['train']) + step
            )
            neg_list = [
                back2color(unnormalize_imagenet_rgb(neg, device))
                for neg in neg_list
            ]
            neg_list = [torch.zeros_like(neg_list[0])] + neg_list
            vid_to_write = torch.stack(neg_list, dim=0).unsqueeze(0)
            writer.add_video(
                'ebm_evolution', vid_to_write, fps=args.ebm_log_fps,
                global_step=epoch * len(data_loaders['train']) + step
            )
        writer.add_scalar(
            'lr', optimizer.state_dict()['param_groups'][0]['lr'], epoch
        )
        # Save checkpoint
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            },
            args.classifier_ckpnt
        )
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            },
            "manipulator_%02d.pt" % (epoch+1)
        )
        print('\nValidation')
        print(eval_manipulator(model, data_loaders['test'], args))
    return model
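
# ``langevin_updates`` is not defined in this snippet; the sketch below is only
# a generic illustration of masked Langevin sampling for an EBM (its name,
# signature, and details are assumptions, not the project's implementation).
import torch

def langevin_updates_sketch(model, x, n_steps, step_size, mask):
    """Run gradient-based MCMC on the masked region of x, descending the energy."""
    intermediate = []
    x = x.detach().requires_grad_(True)
    for _ in range(n_steps):
        energy = model(x).sum()
        grad, = torch.autograd.grad(energy, x)
        with torch.no_grad():
            noise = torch.randn_like(x) * (2 * step_size) ** 0.5
            x -= mask * (step_size * grad + noise)  # update only the masked pixels
        intermediate.append(x.detach().clone())
    return x.detach(), intermediate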
Ejemplo n.º 26
0
 def func():
     new_log_probs = model(inputs)
     old_log_probs = torch.clone(new_log_probs).detach()
     f = mean_kl_multinomial(new_log_probs, old_log_probs)
     return f, list(model.parameters())
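
# ``mean_kl_multinomial`` is not shown; a common definition, assuming both
# arguments are per-sample log-probabilities over categories, is the batch-mean
# KL divergence (the argument order here is an assumption of this sketch):
import torch

def mean_kl_multinomial_sketch(new_log_probs, old_log_probs):
    # KL(new || old) = sum_k p_new(k) * (log p_new(k) - log p_old(k)), averaged over the batch
    kl = (new_log_probs.exp() * (new_log_probs - old_log_probs)).sum(dim=-1)
    return kl.mean()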
Ejemplo n.º 27
0
    def test_memory_format_strides(self, device, dtype):
        shapes = (
            (),
            (0, ),
            (1, ),
            (5, ),
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = ((0, 0, 0, 0), (1, 0, 3, 0), (0, 2, 3, 5),
                                (2, 2, 2, 0), (5, 4, 3, 2), (8, 8, 7, 2),
                                (9, 1, 3, 1), (4, 5, 8, 7))

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shapes, memory_format in pairs:
            for shape in shapes:
                # tests empty
                expected = torch.empty(shape,
                                       device=device,
                                       dtype=dtype,
                                       memory_format=memory_format)
                actual = refs.empty(shape,
                                    device=device,
                                    dtype=dtype,
                                    memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape,
                                              device=device,
                                              dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape,
                                              device=device,
                                              dtype=dtype,
                                              noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())
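
# A quick standalone illustration of the clone behaviour the test checks:
# ``torch.clone`` with a ``memory_format`` argument controls the strides of the copy.
import torch

a = torch.randn(2, 3, 4, 5)                                # contiguous NCHW tensor
nhwc = torch.clone(a, memory_format=torch.channels_last)
print(a.stride(), nhwc.stride())                           # (60, 20, 5, 1) vs (60, 1, 15, 3)
assert torch.clone(a, memory_format=torch.contiguous_format).stride() == a.stride()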
Ejemplo n.º 28
0
    def step(self, closure=None):

        loss = None
        if closure is not None and callable(closure):
            with torch.enable_grad():
                loss = closure()

        param_size = 0
        variance_ma_sum = 0.0

        # phase 1 - accumulate all of the variance_ma_sum to use in stable weight decay

        for i, group in enumerate(self.param_groups):
            for j, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                # if not self.param_size:
                param_size += p.numel()

                # apply agc if enabled
                if self.agc_active:
                    self.agc(p)

                grad = p.grad

                if grad.is_sparse:
                    raise RuntimeError("sparse matrix not supported atm")

                state = self.state[p]
                momentum = group["momentum"]

                # State initialization
                if len(state) == 0:
                    # print("init state")
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["grad_ma"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state["variance_ma"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format)

                    if self.lookahead_active:
                        state["lookahead_params"] = torch.zeros_like(p.data)
                        state["lookahead_params"].copy_(p.data)

                    if self.use_adabelief:
                        state["variance_ma_belief"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)
                    if self.momentum_pnm:
                        state["neg_grad_ma"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)

                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_variance_ma"] = torch.zeros_like(
                            p, memory_format=torch.preserve_format)

                    # Cumulative products of beta1
                    # state["beta1_prod"] = torch.ones_like(
                    #    p.data, memory_format=torch.preserve_format
                    # )

                # centralize gradients
                if self.use_gc:
                    grad = centralize_gradient(
                        grad,
                        gc_conv_only=self.gc_conv_only,
                    )
                # else:
                #    grad = uncentralized_grad

                # phase 1, variance computations

                state["step"] += 1

                step = state["step"]
                lr = group["lr"]

                beta1, beta2 = group["betas"]
                grad_ma = state["grad_ma"]

                bias_correction2 = 1 - beta2**state["step"]
                # print(f"bias2 = {bias_correction2}")

                variance_ma = state["variance_ma"]
                if self.use_adabelief:
                    variance_ma_belief = state["variance_ma_belief"]

                # print(f"variance_ma, upper loop = {variance_ma}")

                # update the exp averages
                if self.use_adabelief:
                    grad_ma.mul_(beta1).add_(grad, alpha=1 - beta1)
                    grad_residual = grad - grad_ma
                    variance_ma_belief.mul_(beta2).addcmul_(grad_residual,
                                                            grad_residual,
                                                            value=1 - beta2)
                # print(f"upper loop grad = {grad.shape}")
                variance_ma.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # print(f"variance_ma, grad adjusted")
                variance_ma_debiased = variance_ma / bias_correction2

                variance_ma_sum += variance_ma_debiased.sum()
                # print(f"variance_ma_sum = {variance_ma_sum}")
                # else: #madgrad

                # should we dupe variance_ma since stable is assuming adam style] variance?

                # stable wd
                # variance_ma_sum += grad_sum_sq.sum()

            # print(f"variance hat sum = {exp_avg_sq_hat_sum}")
            # Calculate the sqrt of the mean of all elements in exp_avg_sq_hat

            # we will run this first epoch only and then memoize
        if not self.param_size:
            self.param_size = param_size
            print(f"params size saved")
            print(f"total param groups = {i+1}")
            print(f"total params in groups = {j+1}")

        if not self.param_size:
            raise ValueError("failed to set param size")

        # debugging
        self.variance_sum_tracking.append(variance_ma_sum.item())

        # stable weight decay
        # if not self.use_madgrad:
        variance_normalized = math.sqrt(variance_ma_sum / param_size)

        # variance_mean = variance_ma_sum / param_size
        if math.isnan(variance_normalized):
            raise RuntimeError("hit nan for variance_normalized")
        # print(f"variance_mean = {variance_mean}")
        # print(f"variance_normalized = {variance_normalized}")
        # else:
        #    variance_normalized = math.pow((variance_ma / self.param_size), .3333)

        # print(f"variance mean sqrt = {variance_normalized}")

        # phase 2 - apply weight decay and step
        # ===========================================
        for group in self.param_groups:
            # print(f"In second phase loop")
            step = state["step"]

            # Perform stable weight decay
            decay = group["weight_decay"]
            eps = group["eps"]
            lr = group["lr"]
            momentum = group["momentum"]

            beta1, beta2 = group["betas"]

            # warmup
            # ======================
            if self.use_warmup and not self.warmup_complete:
                lr = self.warmup_dampening(lr, step)
                # print(f"lr = {lr}")

            # chebyshev
            # ===================
            if self.use_cheb and self.warmup_complete:
                lr = self.get_cheb_lr(lr, step)

            # warmdown
            # ==========
            if self.use_warm_down:
                lr = self.get_warm_down(lr, step)

            # madgrad outer
            ck = 1 - momentum
            lamb = lr * math.pow(step, 0.5)

            # inner loop, params
            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                inner_grad = p.grad

                # stable decay and / or norm loss
                # ==================================
                if decay:
                    if not self.use_madgrad:
                        # stable weight decay
                        p.data.mul_(1 - decay * lr / variance_normalized)
                    else:
                        p.data.mul_(1 - decay * lamb / variance_normalized)

                if self.normloss_active:
                    # apply norm loss
                    unorm = self.unit_norm(p.data)
                    correction = (2 * self.normloss_factor *
                                  (1 - torch.div(1, unorm + self.eps)))
                    p.mul_(1 - lr * correction)

                if self.use_madgrad:
                    # ================== madgrad ============================
                    if "grad_sum_sq" not in state:
                        state["grad_sum_sq"] = torch.zeros_like(
                            p.data).detach()
                        state["s"] = torch.zeros_like(p.data).detach()
                        if momentum != 0:
                            state["x0"] = torch.clone(p.data).detach()

                    if momentum != 0.0 and inner_grad.is_sparse:
                        raise RuntimeError(
                            "momentum != 0 is not compatible with sparse gradients"
                        )

                    # centralize gradients
                    if self.use_gc:
                        inner_grad = centralize_gradient(
                            inner_grad,
                            gc_conv_only=self.gc_conv_only,
                        )

                    grad_sum_sq = state["grad_sum_sq"]
                    s = state["s"]
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
                        x0 = p.data.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

                    # Accumulate second moments

                    # print(f" grad = {grad}")
                    # print(f"lamb = {lamb}")
                    # print(f"gsumsq = {grad_sum_sq}")

                    grad_sum_sq.addcmul_(inner_grad, inner_grad, value=lamb)
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                    # Update s
                    s.data.add_(inner_grad, alpha=lamb)

                    # Step
                    if momentum == 0:
                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
                    else:
                        z = x0.addcdiv(s, rms, value=-1)

                        # p is a moving average of z
                        p.data.mul_(1 - ck).add_(z, alpha=ck)

                else:  # adam with pnm core
                    # ============= adamW with pnm option ========================

                    grad = p.grad

                    beta1, beta2 = group["betas"]

                    grad_ma = state["grad_ma"]
                    variance_ma = state["variance_ma"]
                    if self.use_adabelief:
                        variance_ma_belief = state["variance_ma_belief"]

                    if self.momentum_pnm:

                        max_variance_ma = state["max_variance_ma"]

                        if state["step"] % 2 == 1:
                            grad_ma, neg_grad_ma = (
                                state["grad_ma"],
                                state["neg_grad_ma"],
                            )
                        else:
                            grad_ma, neg_grad_ma = (
                                state["neg_grad_ma"],
                                state["grad_ma"],
                            )

                    bias_correction1 = 1 - beta1**step
                    bias_correction2 = 1 - beta2**step

                    if self.momentum_pnm:
                        # Maintains the maximum of all 2nd moment running avg. till now
                        torch.max(max_variance_ma,
                                  variance_ma,
                                  out=variance_ma)
                        # Use the max. for normalizing running avg. of gradient
                        denom = (variance_ma.sqrt() /
                                 math.sqrt(bias_correction2)).add_(
                                     group["eps"])

                    # centralize gradients
                    if self.use_gc:
                        inner_grad = centralize_gradient(
                            inner_grad,
                            gc_conv_only=self.gc_conv_only,
                        )
                    if not self.use_adabelief:
                        grad_ma.mul_(beta1**2).add_(grad, alpha=1 - beta1**2)

                    noise_norm = math.sqrt((1 + beta2)**2 + beta2**2)

                    step_size = lr / bias_correction1

                    pnmomentum = (grad_ma.mul(1 + self.momentum_pnm).add(
                        neg_grad_ma,
                        alpha=-self.momentum_pnm).mul(1 / noise_norm))

                    p.addcdiv_(pnmomentum, denom, value=-step_size)

                    # denom = variance_biased_ma.sqrt().add(eps)

                    # step_size = lr / bias_correction1

                    # update weights
                    # p.data.add_(weight_mod, alpha=-step_size)
                    # p.addcdiv_(grad_ma, denom, value=-step_size)
        # print(f"\n End optimizer step\n")

        # end of step processes....

        # lookahead
        # ---------------------
        if self.lookahead_active:
            self.lookahead_process_step()

        self.track_epochs(step)
        return loss
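
# ``centralize_gradient`` is referenced above but not shown; gradient
# centralization simply removes the mean over the non-output dimensions of each
# gradient. A minimal sketch, with the ``gc_conv_only`` switch approximated:
import torch

def centralize_gradient_sketch(grad, gc_conv_only=False):
    """Return the gradient with its per-output-slice mean subtracted (sketch)."""
    if gc_conv_only and grad.dim() <= 3:
        return grad                          # only touch conv-style weights
    if grad.dim() > 1:
        dims = tuple(range(1, grad.dim()))
        grad = grad - grad.mean(dim=dims, keepdim=True)
    return grad

g = centralize_gradient_sketch(torch.randn(8, 3, 3, 3))
print(g.mean(dim=(1, 2, 3)))                 # per-filter means are ~0 after centralization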
Ejemplo n.º 29
0
def test_msa_check_base_tokenization():

    # Loop over tests
    for test_data in generate_test_objects():

        # Skip if this is not the MSA transformer
        if test_data["ModelName"] != "esm_msa1_t12_100M_UR50S":
            continue

        # Get a sequence, combo, and target positions for running tests
        random_sequence, parent_combo, target_positions = build_seq_data(
            test_data)

        # Make a base tokenization
        model = test_data["Model"]
        base_tokenization = model._build_base_tokenization(random_sequence)

        # Error if the tokenization dimensionality is off
        with pytest.raises(AssertionError,
                           match="Expected 1 element in first dimension"):
            model._check_base_tokenization(torch.rand(2, 1,
                                                      1), random_sequence,
                                           target_positions, parent_combo)

        with pytest.raises(AssertionError, match="Incorrect token dim"):
            model._check_base_tokenization(torch.rand(1, 1), random_sequence,
                                           target_positions, parent_combo)

        # Error if the combo and mutant positions are off
        fake_parent = [None] * len(parent_combo)
        for i, parent_aa in enumerate(parent_combo):
            allowed_muts = [mut for mut in ALL_AAS if mut != parent_combo[i]]
            fake_parent[i] = random.choice(allowed_muts)

        with pytest.raises(
                AssertionError,
                match="Unaligned parent combo and mutant positions"):
            model._check_base_tokenization(base_tokenization, random_sequence,
                                           target_positions, fake_parent)

        # Error if the number of alignments is off
        bad_tokenization = torch.cat((base_tokenization, base_tokenization),
                                     axis=1)
        with pytest.raises(AssertionError,
                           match="Incorrect tokenization of alignments"):
            model._check_base_tokenization(bad_tokenization, random_sequence,
                                           target_positions, parent_combo)

        # Error if the sequence length is off
        bad_tokenization2 = torch.cat(
            (base_tokenization,
             torch.ones(*base_tokenization.shape, dtype=torch.long)),
            axis=2)
        with pytest.raises(AssertionError,
                           match="Expect addition of cls. Refseq length off."):
            model._check_base_tokenization(bad_tokenization2, random_sequence,
                                           target_positions, parent_combo)

        # Error if cls isn't first
        no_cls = torch.clone(base_tokenization)
        no_cls[:, :, 0] = model.tok_to_idx[model.eos_string]
        with pytest.raises(AssertionError, match="Expect addition of cls"):
            model._check_base_tokenization(no_cls, random_sequence,
                                           target_positions, parent_combo)

        # Pass in a bad sequence. We should error.
        seqlen = len(random_sequence[0][1])
        bad_seq = [random.choice(ALL_AAS) for _ in range(seqlen)]
        random_sequence[0][1] = bad_seq
        with pytest.raises(AssertionError,
                           match="Tokenization does not represent alignment"):
            model._check_base_tokenization(base_tokenization, random_sequence,
                                           target_positions, parent_combo)
Ejemplo n.º 30
0
 loss = lsInfo["conf"] + lsInfo["box"] + lsInfo["cls"] * bool(
     cfg.model.clsNum - 1)
 loss = loss / cfg.train.batchSize
 l1, l2 = Regularization(network)
 loss += cfg.loss.l2 * l2
 """backward"""
 optimizer.zero_grad()
 loss.backward()
 nn.utils.clip_grad_value_(network.parameters(),
                           100)  # gradient clip
 optimizer.step()
 # scheduler.step(loss)  # other metrics could be used here as well
 """print"""
 niter = e * len(trainLoader) + id + 1
 with torch.no_grad():
     lossS = torch.clone(loss).to('cpu').numpy()
     lsConf = torch.clone(lsInfo['conf']).to('cpu').numpy()
     lsBox = torch.clone(lsInfo['box']).to('cpu').numpy()
     lsCls = torch.clone(lsInfo['cls']).to('cpu').numpy()
     if id % 30 == 0:
         print(
             "[bc:{}/{} e: {}/{} total_bc:{} per:{:.3f}%]".format(
                 id, len(trainLoader), e, cfg.train.epoch, batchNum,
                 float(niter * 100) / batchNum),
             "loss:[{:.4f} conf:{:.4f} cls:{:.4f} box:{:.4f} l2:{:.4f} lr:{:.7f}"
             .format(lossS / cfg.train.batchSize,
                     lsConf / cfg.train.batchSize,
                     lsCls / cfg.train.batchSize,
                     lsBox / cfg.train.batchSize, l2, lr))
 """tensorboardX to view"""
 #loss add per iter