Example #1
def calc_gammas(batch, gamma):
    '''Calculate the gammas to the right power for multiplication with rewards'''
    news = torch.cat([torch.ones((1,)), batch['dones'][:-1]])
    gammas = torch.empty_like(news)
    cur_gamma = 1.0
    for t, new in enumerate(news):
        cur_gamma = new * 1.0 + (1 - new) * cur_gamma * gamma
        gammas[t] = cur_gamma
    return gammas
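
A quick usage check (illustrative, not from the original source) of calc_gammas above: gammas restart at 1.0 on the step after a done flag and otherwise accumulate powers of gamma.

import torch

# Illustrative batch; 'dones' marks episode terminations.
batch = {'dones': torch.tensor([0., 0., 1., 0., 0.])}
print(calc_gammas(batch, gamma=0.9))
# expected: tensor([1.0000, 0.9000, 0.8100, 1.0000, 0.9000])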
Example #2
 def func(x):
     # lower, upper, filler, hist_sum, factor and hist are free variables
     # expected to be defined in the enclosing scope.
     out = torch.empty_like(x)
     in_bound = (lower <= x) * (x <= upper)
     out[~in_bound] = filler
     if upper == lower:
         out[in_bound] = hist_sum
     else:
         pos = (x[in_bound] - lower) * factor
         index, frac = pos.long(), pos % 1
         next_index = (index + 1).clamp(max=hist.size(0) - 1)
         out[in_bound] = (1 - frac) * hist[index] + frac * hist[next_index]
     return out
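
A standalone sketch of how the free variables above might be set up to call func (assuming func is defined at module scope). The values below are assumptions for illustration only; in particular, factor is assumed to map [lower, upper] onto the histogram's bin-index range.

import torch

# Assumed setup: none of these names are defined in the snippet above.
lower, upper, filler = 0.0, 1.0, 0.0
hist = torch.tensor([0.0, 1.0, 4.0, 9.0, 16.0])
factor = (hist.size(0) - 1) / (upper - lower)  # maps [lower, upper] onto bin indices
hist_sum = hist.sum()

x = torch.tensor([-0.5, 0.0, 0.125, 0.5, 1.0, 2.0])
out = func(x)
print(out)  # out-of-range entries receive filler; in-range values are linearly interpolated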
Example #3
 def forward(self, x):
     # returns an uninitialized tensor with the same shape, dtype and device as x
     out = torch.empty_like(x)
     return out
Example #4
# torch.rand
# torch.empty
# torch.zeros
# torch.randn_like()
# x.new_ones()  # creates a new tensor derived from x
x = x.new_ones(5, 3, dtype=torch.double)

# Tensor size
print(x.size())
print(x.shape)

################# Tensor operations
torch.add(x, y)  # y is assumed to be a tensor of the same shape as x

# The output tensor can be specified via the out argument
result = torch.empty_like(x)
torch.add(x, y, out=result)

# Resizing: use the view function
x = torch.randn(4, 4)
y = x.view(16)
z = x.view(-1, 8)
print(x.size(), y.shape, z.shape)

# In-place operations
# Appending _ to a method name overwrites the variable in place
print(y)
print(y.add(x)) 
print(y)  # y is unchanged
y.add_(x)
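
A short note on empty_like itself (illustrative, not part of the tutorial above): it allocates uninitialized memory matching the argument's size, dtype and device, so the result must be written to (e.g. via out= or an in-place method) before it is read.

import torch

x = torch.randn(4, 4)
result = torch.empty_like(x)   # same shape/dtype/device as x, contents arbitrary
torch.add(x, x, out=result)    # fill it via out=
result.add_(1.0)               # trailing underscore = in-place update
print(result.shape, result.dtype)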
Example #5
def main(args):
    chosen_method_log = dict()  # this writes things when method is changed
    current_method_log = dict()  # this will monitor what is the current method
    candidate_method_stat = dict(
    )  # this tracks the thresh for all candidate method
    timing_log = defaultdict(list)
    floats_communicated = dict()
    grad_calc_dict = dict()
    ratio_calc_dict = dict()
    prev_norm = None
    json_f_name = os.path.basename(args.norm_file).split('.')[0] + '.json'
    current_method_log_fname = os.path.basename(
        args.norm_file).split('.')[0] + "_per_epoch_method.json"
    candidate_methods_stat_fname = os.path.basename(
        args.norm_file).split('.')[0] + "_candidate_method_stats.json"
    timing_log_fname = os.path.basename(
        args.norm_file).split('.')[0] + "_timing_log.json"
    bytes_log_fname = os.path.basename(
        args.norm_file).split('.')[0] + "_floats_communicated.json"
    ratio_log_fname = os.path.basename(
        args.norm_file).split('.')[0] + "_ratio_vals.json"
    grad_calc_fname = os.path.basename(
        args.norm_file).split('.')[0] + "_grad_norm_vals.json"
    #TODO: Clean this up to manually select the model
    if args.model_type == "CNN":
        config = cifar_config
    elif args.model_type == "languageModel":
        config = lstm_config
    elif args.model_type == "newlanguageModel":
        config = new_lstm_config
    elif args.model_type == "imagenet":
        config = imagenet_config
    elif args.model_type == "cifar100":
        config = cifar100_config
    elif args.model_type == "svhn":
        config = svhn_config
    elif args.model_type == "squeezenet_cifar":
        config = cifar_squeezenet_config
    else:
        raise NotImplementedError("{} not implemented".format(args.model_type))
    config['is_distributed'] = False  # adding a new key in the config
    # overriding the network with user input
    config['arch'] = args.network
    if args.distributed:
        print("Initializing distributed")
        dist.init_process_group(backend="NCCL",
                                init_method=args.master_ip,
                                timeout=datetime.timedelta(seconds=120),
                                world_size=args.num_nodes,
                                rank=args.rank)
        config['is_distributed'] = True
        print("Distributed Initialized")
    train_task = train_network.build(config['dataset'], config)
    #TODO: Fix this for distributed
    # use parameter groups to get things for different learning rates
    # and weight decay parameters
    current_lr = config['init_lr']
    if config['name'] == "CNN" or config['name'] == 'cifar100' or config[
            'name'] == 'imagenet' or config['name'] == 'svhn':
        # optimizer is used only for these (vision) models;
        # for the language models the parameters are updated manually below
        # my guess is that the repackage step for language models changes
        # the model structure and the optimizer is registered only for some of
        # the parameters
        optimizer = optim.SGD(train_task.model.parameters(),
                              lr=current_lr,
                              momentum=config['momentum'],
                              weight_decay=0.0001)

    if config['name'] == "squeezenet_cifar":
        # special optimizer for squeezenet
        optimizer = optim.SGD(train_task.model.parameters(),
                              lr=current_lr,
                              momentum=config['momentum'],
                              weight_decay=5e-4)

    # list containing the applySparsify object collection;
    # the applySparsify method will handle everything.
    # None if no reduction is needed for the corresponding parameter
    sparsify_method = [
        sparsify_gradient_topk.applySparsify(p.shape, config['device'])
        if p.ndimension() > 1 else None for p in train_task.model.parameters()
    ]
    # import ipdb; ipdb.set_trace()
    # Temporary: to test code with a fixed k
    if not args.fixed_k and not args.auto_switch:
        print("Warning: Full Rank SGD being done")
    if args.fixed_k:
        print("Chose a fixed k, k= {}".format(args.k))
        for m in sparsify_method:
            if m is not None:
                m.update_method(args.k, args.zero_memory)
            else:
                pass
    if args.start_k:
        print("Starting with fixed k ={}".format(args.k_start))
        for m in sparsify_method:
            if m is not None:
                m.update_method(args.k_start, args.zero_memory)
            else:
                pass
    current_test_loss = None
    best_test_loss = None
    momenta = [
        torch.empty_like(param) for param in train_task.model.parameters()
    ]
    first_iter = 0  # hack for momentum code

    for epoch in range(config['num_epochs']):
        step_iter = train_task.train_single_iter(epoch=epoch,
                                                 logger=logger,
                                                 for_autoscale=False)
        # i think somebody is not cleaning up the gradients and that's causing
        # the problem
        # train_task.model.zero_grad()
        # train_task.model.train()
        if args.fixed_sched:
            print("Following Fixed schedule")
            if epoch == 20:
                for idx, m in enumerate(sparsify_method):
                    if m is not None:
                        # if idx <= 68:
                        # m.update_method(1, args.zero_memory)
                        # else:
                        m.update_method(args.k_start, args.zero_memory)
                    else:
                        pass
            # if epoch == 110:
            #    for m in sparsify_method:
            #        if m is not None:
            #            m.update_method(4, args.zero_memory)
            #        else:
            #            pass

            #if epoch == 130:
            #    for m in sparsify_method:
            #        if m is not None:
            #            m.update_method(4, args.zero_memory)
            #        else:
            #            pass
            if epoch == 150:
                for m in sparsify_method:
                    if m is not None:
                        m.update_method(args.k_start, args.zero_memory)
                    else:
                        pass

            if epoch == 170:
                for m in sparsify_method:
                    if m is not None:
                        m.update_method(args.k_start, args.zero_memory)
                    else:
                        pass
            if epoch == 250:
                for m in sparsify_method:
                    if m is not None:
                        m.update_method(args.k_start, args.zero_memory)
                    else:
                        pass

            if epoch == 260:
                for m in sparsify_method:
                    if m is not None:
                        m.update_method(args.k_start, args.zero_memory)
                    else:
                        pass

        tic = time.time()
        elements_per_epoch = 0
        # if epoch != 0:
        # print("Norm of gradients before starting {} at epoch {}".format([
        # torch.norm(l.grad.data).item() for l in train_task.model.parameters()]
        # ,epoch))
        # net = {
        # 'state': train_task.model.state_dict()
        # }
        # torch.save(net, "epoch_{}_before_training.pth".format(epoch))
        full_rank_accum = [
            torch.zeros_like(copy_l)
            for copy_l in train_task.model.parameters()
        ]
        for grad_train in step_iter:
            # TODO: Think carefully how you want to modify the gradients
            out_grad_list = list()  # list to store output gradients
            for idx, grad_val in enumerate(grad_train):
                full_rank_accum[idx].add_(grad_val.data)
                sparse_object = sparsify_method[idx]
                if sparse_object is not None:
                    out_grad_reduced, bytes_comm = sparse_object.apply_method(
                        grad_val)
                    # out_grad_list.append(sparse_object.apply_method(grad_val))
                    out_grad_list.append(out_grad_reduced)
                    elements_per_epoch += bytes_comm
                else:
                    # in case of distributed need to all reduce the singular
                    # values
                    if args.distributed:
                        elements_per_epoch += torch.numel(grad_val)
                        torch.distributed.all_reduce(grad_val, async_op=False)
                        grad_val[:] = grad_val / args.num_nodes
                    out_grad_list.append(grad_val)
            # update the gradients in place
            # TODO: Move this to a new function
            for idx, param in enumerate(train_task.model.parameters()):
                param.grad.data = out_grad_list[idx]
            if config['name'] == 'CNN' or config[
                    'name'] == 'cifar100' or config[
                        'name'] == 'imagenet' or config['name'] == 'svhn':
                optimizer.step()
                optimizer.zero_grad()
            if config['name'] == "squeezenet_cifar":
                optimizer.step()
                optimizer.zero_grad()
            elif config['name'] == 'languageModel' or config[
                    'name'] == 'newlanguageModel':
                # momentum implementation
                for idx, param in enumerate(train_task.model.parameters()):
                    if epoch == 0 and first_iter == 0:
                        momenta[idx].data = param.grad.data.clone().detach()
                        first_iter = 1
                    else:
                        momenta[idx].data.mul_(0.9).add_(param.grad.data)
                    param.grad.data[:] += momenta[idx].data

                for p in train_task.model.parameters():
                    p.data.add_(-current_lr, p.grad.data)
                train_task.model.zero_grad()
        toc = time.time()
        timing_log[epoch].append(tic)
        timing_log[epoch].append(toc)
        floats_communicated[epoch] = elements_per_epoch
        grad_calc_dict[epoch] = [
            torch.norm(pval).item() for pval in full_rank_accum
        ]
        # dumping training method used every epoch
        # mostly for sanity checking
        # commenting out for future use

        # if epoch%10 == 0:
        # if args.rank == 0:
        # # net = {
        # # 'state': train_task.model.state_dict()
        # # }
        # # torch.save(net, "./saved_model.pth")

        # # norm_list = train_task.get_train_norm("./saved_model.pth",
        # # config)
        # # print (norm_list)
        # # grad_calc_dict[epoch] = norm_list
        # if epoch == 0:
        # old_grad_norms = None
        # else:
        # old_grad_norms = grad_calc_dict[epoch-10]
        # current_grad_norms = grad_calc_dict[epoch]
        # auto_scale_list, ratio_list =run_auto_scale(
        # current_grad_norms, old_grad_norms, epoch)

        # else:
        # time.sleep(5)
        torch.distributed.barrier()
        method_array = list()
        for mth_sp in sparsify_method:
            if mth_sp is None:
                method_array.append("FullRank")
            elif mth_sp.k is None:
                method_array.append("FullRank")
            else:
                method_array.append(mth_sp.k)
        current_method_log[epoch] = method_array
        with open(current_method_log_fname, "w") as fout:
            json.dump(current_method_log, fout)
        with open(timing_log_fname, "w") as fout:
            json.dump(timing_log, fout)
        with open(bytes_log_fname, "w") as fout:
            json.dump(floats_communicated, fout)
        with open(grad_calc_fname, "w") as fout:
            json.dump(grad_calc_dict, fout)
        # import ipdb; ipdb.set_trace()
        if args.auto_switch:
            print("Auto switching enabled")
            if epoch % config['switch_freq'] == 0:
                #TODO: Make acceptable k configurable from args or the config dict
                auto_scale_tensor = torch.zeros(len(sparsify_method),
                                                device="cuda:0",
                                                dtype=torch.float32)
                if args.rank == 0:
                    # only doing it for master
                    #TODO: Make that 4 configurable
                    # ratio_val, prev_norm, auto_scale_per_layer = auto_scale.run_auto_scale_gng(train_task,
                    # 4, args.norm_thresh, prev_norm)

                    if epoch == 0:
                        old_grad_norms = None
                    else:
                        old_grad_norms = grad_calc_dict[epoch -
                                                        config['switch_freq']]
                        # will give the previous grads
                    current_grad_norms = grad_calc_dict[epoch]
                    auto_scale_per_layer, ratio_val = auto_scale_topk.run_auto_scale_gng(
                        current_grad_norms, old_grad_norms, epoch)
                    # auto_scale_divergence_list = auto_scale.run_auto_scale_divergence(
                    # grad_calc_dict, epoch, config['num_epochs'],
                    # config['switch_freq'])
                    # if auto_scale_divergence_list is not None:
                    # for idx, value_in in enumerate(auto_scale_per_layer):
                    # auto_scale_per_layer[idx] = max(
                    # auto_scale_per_layer[idx],
                    # auto_scale_divergence_list[idx])
                    #CAUTION: Bad hack to dump values and test
                    # auto_scale_per_layer = [4]*len(auto_scale_tensor)

                    print("Auto scale per layer calculated = {} at rank {}".
                          format(auto_scale_per_layer, args.rank))
                    # there could be None in auto_scale_per_layer
                    # to clean that up I use this map
                    #TODO: Add flags and condition checks for single machine
                    auto_scale_per_layer = list(
                        map(lambda x: 999 if x is None else x,
                            auto_scale_per_layer))

                    auto_scale_tensor = torch.tensor(
                        auto_scale_per_layer, dtype=torch.float32).to('cuda:0')
                # broadcast autoscale values
                print("Auto scale tensor before = {} for rank {}".format(
                    auto_scale_tensor, args.rank))
                torch.distributed.broadcast(auto_scale_tensor, 0)
                print("Auto Scale Tensor after = {} for rank {}".format(
                    auto_scale_tensor, args.rank))
                auto_scale_per_layer = auto_scale_tensor.tolist()
                # substituting None back
                auto_scale_per_layer = list(
                    map(lambda x: None if x == 999 else x,
                        auto_scale_per_layer))
                print("Auto scale list = {} for rank {}".format(
                    auto_scale_per_layer, args.rank))
                if args.rank == 0:
                    candidate_method_stat[epoch] = prev_norm
                    ratio_calc_dict[epoch] = ratio_val
                for idx, spm in enumerate(auto_scale_per_layer):
                    chosen_method = auto_scale_per_layer[idx]
                    sparse_mth = sparsify_method[idx]
                    if sparse_mth is not None:
                        sparse_mth.update_method(chosen_method,
                                                 args.zero_memory)
                    else:
                        auto_scale_per_layer[
                            idx] = None  # so that json is clean
                chosen_method_log[epoch] = auto_scale_per_layer
                with open(json_f_name, "w") as fout:
                    json.dump(chosen_method_log, fout)
                with open(candidate_methods_stat_fname, "w") as fout:
                    json.dump(candidate_method_stat, fout)
                with open(ratio_log_fname, "w") as fout:
                    json.dump(ratio_calc_dict, fout)

        train_task.model.eval()
        current_test_loss = train_task.validate_model(logger)
        if not best_test_loss or current_test_loss < best_test_loss:
            best_test_loss = current_test_loss
        # updating the learning rate
        prev_lr = current_lr
        if config['name'] != "squeezenet_cifar":
            current_lr = get_lr(config, epoch, current_lr, best_test_loss,
                                current_test_loss)
        else:
            current_lr, current_wd = get_lr_squeezenet(config, epoch)

        if current_lr < prev_lr:
            # Second rule of new auto scale
            # at decay point what to do
            if args.auto_switch:
                print("Epoch {} deacy time making it k= {}".format(
                    epoch, auto_scale_high))
                for m in sparsify_method:
                    if m is not None:
                        m.update_method(auto_scale_high)
                    else:
                        pass

        train_task.lr = current_lr  # mostly for logging
        #TODO: Add one more logging to make sure that k is correct
        # this will read the sparsify method array and write out the

        if config['name'] == 'CNN' or config['name'] == 'cifar100' or config[
                'name'] == 'svhn' or config['name'] == 'imagenet':
            for group in optimizer.param_groups:
                group['lr'] = current_lr
        if config['name'] == "squeezenet_cifar":
            for group in optimizer.param_groups:
                group['lr'] = current_lr
                group['weight_decay'] = current_wd
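
For context, a minimal sketch of the top-k gradient sparsification that the applySparsify objects above apply per layer. This illustrates the general technique only; it is not the project's applySparsify implementation: keep the k largest-magnitude entries of a gradient tensor and zero the rest.

import torch

def topk_sparsify(grad, k):
    # keep the k largest-magnitude entries, zero everything else
    flat = grad.flatten()
    if k >= flat.numel():
        return grad.clone()
    out = torch.zeros_like(flat)
    _, idx = torch.topk(flat.abs(), k)
    out[idx] = flat[idx]
    return out.view_as(grad)

g = torch.randn(4, 4)
print(topk_sparsify(g, 3))  # only 3 non-zero entries remain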
Example #6
    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
        thresh: Union[float, torch.Tensor] = -52.0,
        rest: Union[float, torch.Tensor] = -65.0,
        reset: Union[float, torch.Tensor] = -65.0,
        refrac: Union[int, torch.Tensor] = 5,
        tc_decay: Union[float, torch.Tensor] = 100.0,
        theta_plus: Union[float, torch.Tensor] = 0.05,
        tc_theta_decay: Union[float, torch.Tensor] = 1e7,
        lbound: float = None,
        one_spike: bool = True,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Instantiates a layer of Diehl & Cook 2015 neurons.

        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        :param thresh: Spike threshold voltage.
        :param rest: Resting membrane voltage.
        :param reset: Post-spike reset voltage.
        :param refrac: Refractory (non-firing) period of the neuron.
        :param tc_decay: Time constant of neuron voltage decay.
        :param theta_plus: Voltage increase of threshold after spiking.
        :param tc_theta_decay: Time constant of adaptive threshold decay.
        :param lbound: Lower bound of the voltage.
        :param one_spike: Whether to allow only one spike per timestep.
        """
        super().__init__(
            n=n,
            shape=shape,
            traces=traces,
            traces_additive=traces_additive,
            tc_trace=tc_trace,
            trace_scale=trace_scale,
            sum_input=sum_input,
        )

        self.register_buffer("rest", torch.tensor(rest))  # Rest voltage.
        self.register_buffer("reset",
                             torch.tensor(reset))  # Post-spike reset voltage.
        self.register_buffer("thresh",
                             torch.tensor(thresh))  # Spike threshold voltage.
        self.register_buffer(
            "refrac", torch.tensor(refrac))  # Post-spike refractory period.
        self.register_buffer(
            "tc_decay",
            torch.tensor(tc_decay))  # Time constant of neuron voltage decay.
        self.register_buffer("decay", torch.empty_like(
            self.tc_decay))  # Set in compute_decays.
        self.register_buffer(
            "theta_plus",
            torch.tensor(theta_plus))  # Constant threshold increase on spike.
        self.register_buffer("tc_theta_decay", torch.tensor(
            tc_theta_decay))  # Time constant of adaptive threshold decay.
        self.register_buffer(
            "theta_decay",
            torch.empty_like(self.tc_theta_decay))  # Set in compute_decays.
        self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
        self.register_buffer("theta",
                             torch.zeros(*self.shape))  # Adaptive thresholds.
        self.register_buffer(
            "refrac_count", torch.FloatTensor())  # Refractory period counters.

        self.lbound = lbound  # Lower bound of voltage.
        self.one_spike = one_spike  # One spike per timestep.
Example #7
    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
        learning: bool = True,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Abstract base class constructor.

        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record decaying spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        :param learning: Whether to be in learning or testing.
        """
        super().__init__()

        assert (n is not None or shape is not None
                ), "Must provide either no. of neurons or shape of layer"

        if n is None:
            self.n = reduce(mul, shape)  # No. of neurons product of shape.
        else:
            self.n = n  # No. of neurons provided.

        if shape is None:
            self.shape = [self.n]  # Shape is equal to the size of the layer.
        else:
            self.shape = shape  # Shape is passed in as an argument.

        assert self.n == reduce(
            mul, self.shape), "No. of neurons and shape do not match"

        self.traces = traces  # Whether to record synaptic traces.
        self.traces_additive = (traces_additive
                                )  # Whether to record spike traces additively.
        self.register_buffer("s", torch.ByteTensor())  # Spike occurrences.

        self.sum_input = sum_input  # Whether to sum all inputs.

        if self.traces:
            self.register_buffer("x", torch.Tensor())  # Firing traces.
            self.register_buffer(
                "tc_trace",
                torch.tensor(tc_trace))  # Time constant of spike trace decay.
            if self.traces_additive:
                self.register_buffer("trace_scale", torch.tensor(
                    trace_scale))  # Scaling factor for spike trace.
            self.register_buffer("trace_decay", torch.empty_like(
                self.tc_trace))  # Set in compute_decays.

        if self.sum_input:
            self.register_buffer("summed",
                                 torch.FloatTensor())  # Summed inputs.

        self.dt = None
        self.batch_size = None
        self.trace_decay = None
        self.learning = learning
Example #8
 def color_augmentation(img):
     for c in range(3):
         img[:, :, c].mul_(torch.empty_like(img[:, :, c]).uniform_(0.6, 1.4)) \
             .clamp_(0, 1)
     return img
Example #9
def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)
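
An illustrative call (input made up; note _indices()/_values() are internal sparse-tensor accessors): _make_sparse rebuilds a sparse tensor with the same sparsity pattern as a sparse gradient but with new values.

import torch

grad = torch.tensor([[0., 2.], [3., 0.]]).to_sparse()
new_values = grad._values() ** 2
print(_make_sparse(grad, grad._indices(), new_values).to_dense())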
Example #10
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        if not input.is_contiguous(memory_format=torch.channels_last):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # for empty input, set stats and the count to zero. The stats with
            # zero count will be filtered out later when computing global mean
            # & invstd, but they still needs to participate the all_gather
            # collective communication to unblock other peer processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use allgather instead of allreduce because count could be different across
        # ranks; a simple all reduce op cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
        # all gathered mean, invstd and count.
        # for nccl backend, use the optimized version of all gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not torch.cuda.is_current_stream_capturing():
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in mask tensor.
            # Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)
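
A small numerical check (illustrative, plain tensors, no distributed setup) of why the per-rank counts are gathered along with the stats: the global mean and variance are the count-weighted combination of the per-rank statistics, which is what batch_norm_gather_stats_with_counts computes from the gathered values.

import torch

a, b = torch.randn(7), torch.randn(3)          # pretend these live on two ranks
counts = torch.tensor([a.numel(), b.numel()], dtype=torch.float)
means = torch.stack([a.mean(), b.mean()])
vars_ = torch.stack([a.var(unbiased=False), b.var(unbiased=False)])

n = counts.sum()
global_mean = (counts * means).sum() / n
global_var = (counts * (vars_ + means ** 2)).sum() / n - global_mean ** 2

full = torch.cat([a, b])
print(torch.allclose(global_mean, full.mean(), atol=1e-6))
print(torch.allclose(global_var, full.var(unbiased=False), atol=1e-6))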
Example #11
def correct_phases(reals, imags, dim, inv=False):
    """ Corrects wavelet phases so the centres line up.

    i.e. This makes sure the centre of the real wavelet is a zero crossing for
    all orientations.

    For the inverse path, we needed to multiply the coefficients by the phase
    correcting factor of: [1j, -1j, 1j, -1, 1, -1].
    For the forward path, we divide by these numbers, or multiply by their
    complex conjugate.

    Parameters
    ----------
    reals: torch.tensor of floats with the 6 orientations in the second
        dimension
    imags: torch.tensor of floats with the 6 orientations in the second
        dimension
    inv : bool
        Whether this is the forward or backward pass.  Default is false (i.e.
        forward)
    """
    r = torch.empty_like(reals)
    i = torch.empty_like(imags)
    if dim == 2:
        if inv:
            m1 = torch.tensor([1, -1, 1],
                              dtype=imags.dtype,
                              device=imags.device)
            m2 = torch.tensor([-1, 1, -1],
                              dtype=imags.dtype,
                              device=imags.device)
            m1 = m1.view(1, 1, 3, 1, 1)
            m2 = m2.view(1, 1, 3, 1, 1)
        else:
            m1 = torch.tensor([-1, 1, -1],
                              dtype=imags.dtype,
                              device=imags.device)
            m2 = torch.tensor([-1, 1, -1],
                              dtype=imags.dtype,
                              device=imags.device)
            m1 = m1.view(1, 1, 3, 1, 1)
            m2 = m2.view(1, 1, 3, 1, 1)
        r[:, :, :3] = imags[:, :, :3] * m1
        i[:, :, :3] = reals[:, :, :3] * -m1
        r[:, :, 3:] = reals[:, :, 3:] * m2
        i[:, :, 3:] = imags[:, :, 3:] * m2
    elif dim == 1:
        if inv:
            m1 = torch.tensor([1, -1, 1],
                              dtype=imags.dtype,
                              device=imags.device)
            m2 = torch.tensor([-1, 1, -1],
                              dtype=imags.dtype,
                              device=imags.device)
            m1 = m1.view(1, 3, 1, 1)
            m2 = m2.view(1, 3, 1, 1)
        else:
            m1 = torch.tensor([-1, 1, -1],
                              dtype=imags.dtype,
                              device=imags.device)
            m2 = torch.tensor([-1, 1, -1],
                              dtype=imags.dtype,
                              device=imags.device)
            m1 = m1.view(1, 3, 1, 1)
            m2 = m2.view(1, 3, 1, 1)
        r[:, :3] = imags[:, :3] * m1
        i[:, :3] = reals[:, :3] * -m1
        r[:, 3:] = reals[:, 3:] * m2
        i[:, 3:] = imags[:, 3:] * m2
    return r, i
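
A quick shape sanity check (illustrative, random data): for dim=2 the coefficients are expected to be 5-D with the 6 orientations along dimension 2, and the outputs keep the input shapes.

import torch

reals = torch.randn(1, 4, 6, 8, 8)   # (batch, channels, 6 orientations, H, W)
imags = torch.randn(1, 4, 6, 8, 8)
r, i = correct_phases(reals, imags, dim=2)
print(r.shape, i.shape)              # both match the input shape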
Example #12
print(x.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``clone`` preserves memory format
y = x.clone()
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``to``, ``cuda``, ``float`` ... preserve memory format
if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# ``empty_like``, ``*_like`` operators preserve memory format
y = torch.empty_like(x)
print(y.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Pointwise operators preserve memory format
z = x + y
print(z.stride())  # Outputs: (3072, 1, 96, 3)

######################################################################
# Conv and Batchnorm modules using the cuDNN backend support channels last
# (only works for cuDNN >= 7.6). Convolution modules, unlike binary
# pointwise operators, have channels last as the dominating memory format:
# if and only if all inputs are in contiguous memory format, the operator
# produces output in contiguous memory format. Otherwise, the output will
# be in channels last memory format.
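
A small check of that last point (illustrative; on GPU it exercises the cuDNN path the comment refers to): a Conv2d fed a channels-last input is expected to produce channels-last output.

import torch
import torch.nn as nn

x = torch.randn(8, 3, 32, 32)
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
if torch.cuda.is_available():
    x, conv = x.cuda(), conv.cuda()
x = x.to(memory_format=torch.channels_last)
y = conv(x)
print(y.is_contiguous(memory_format=torch.channels_last))  # expected: True where the backend supports channels last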
Example #13
def linear_cg(
    matmul_closure,
    rhs,
    n_tridiag=0,
    tolerance=1e-6,
    eps=1e-20,
    max_iter=None,
    max_tridiag_iter=None,
    initial_guess=None,
    preconditioner=None,
):
    """
    Implements the linear conjugate gradients method for (approximately) solving systems of the form

        lhs result = rhs

    for positive definite and symmetric matrices.

    Args:
      - matmul_closure - a function which performs a left matrix multiplication with lhs_mat
      - rhs - the right-hand side of the equation
      - n_tridiag - returns a tridiagonalization of the first n_tridiag columns of rhs
      - tolerance - stop the solve when the max residual is less than this
      - eps - noise to add to prevent division by zero
      - max_iter - the maximum number of CG iterations
      - max_tridiag_iter - the maximum size of the tridiagonalization matrix
      - initial_guess - an initial guess at the solution `result`
      - preconditioner - a function which left-preconditions a supplied vector

    Returns:
      result - a solution to the system (if n_tridiag is 0)
      result, tridiags - a solution to the system, and corresponding tridiagonal matrices (if n_tridiag > 0)
    """
    # Unsqueeze, if necessary
    is_vector = rhs.ndimension() == 1
    if is_vector:
        rhs = rhs.unsqueeze(-1)

    # Some default arguments
    if max_iter is None:
        max_iter = settings.max_cg_iterations.value()
    if max_tridiag_iter is None:
        max_tridiag_iter = settings.max_lanczos_quadrature_iterations.value()
    if initial_guess is None:
        initial_guess = torch.zeros_like(rhs)
    if preconditioner is None:
        preconditioner = _default_preconditioner

    # If we are running m CG iterations, we obviously can't get more than m Lanczos coefficients
    if max_tridiag_iter > max_iter:
        raise RuntimeError(
            "Getting a tridiagonalization larger than the number of CG iterations run is not possible!"
        )

    # Check matmul_closure object
    if torch.is_tensor(matmul_closure):
        matmul_closure = matmul_closure.matmul
    elif not callable(matmul_closure):
        raise RuntimeError(
            "matmul_closure must be a tensor, or a callable object!")

    # Get some constants
    batch_shape = rhs.shape[:-2]
    num_rows = rhs.size(-2)
    n_iter = min(max_iter,
                 num_rows) if settings.terminate_cg_by_size.on() else max_iter
    n_tridiag_iter = min(max_tridiag_iter, num_rows)

    # result <- x_{0}
    result = initial_guess

    # residual: residual_{0} = b_vec - lhs x_{0}
    residual = rhs - matmul_closure(result)

    # Check for NaNs
    if not torch.equal(residual, residual):
        raise RuntimeError(
            "NaNs encounterd when trying to perform matrix-vector multiplication"
        )

    # Sometimes we're lucky and the preconditioner solves the system right away
    residual_norm = residual.norm(2, dim=-2)
    if (residual_norm < tolerance).all() and not n_tridiag:
        n_iter = 0  # Skip the iteration!

    # Otherwise, let's define precond_residual and curr_conjugate_vec
    else:
        # precon_residual{0} = M^-1 residual_{0}
        precond_residual = preconditioner(residual)
        curr_conjugate_vec = precond_residual
        residual_inner_prod = precond_residual.mul(residual).sum(-2,
                                                                 keepdim=True)

        # Define storage matrices
        mul_storage = torch.empty_like(residual)
        alpha = torch.empty(*batch_shape,
                            rhs.size(-1),
                            dtype=residual.dtype,
                            device=residual.device)
        beta = torch.empty_like(alpha)

    # Define tridiagonal matrices, if applicable
    if n_tridiag:
        t_mat = torch.zeros(n_tridiag_iter,
                            n_tridiag_iter,
                            *batch_shape,
                            n_tridiag,
                            dtype=alpha.dtype,
                            device=alpha.device)
        alpha_reciprocal = torch.empty(*batch_shape,
                                       n_tridiag,
                                       dtype=t_mat.dtype,
                                       device=t_mat.device)
        prev_alpha_reciprocal = torch.empty_like(alpha_reciprocal)
        prev_beta = torch.empty_like(alpha_reciprocal)

    update_tridiag = True
    last_tridiag_iter = 0
    # Start the iteration
    for k in range(n_iter):
        # Get next alpha
        # alpha_{k} = (residual_{k-1}^T precon_residual{k-1}) / (p_vec_{k-1}^T mat p_vec_{k-1})
        mvms = matmul_closure(curr_conjugate_vec)
        torch.mul(curr_conjugate_vec, mvms, out=mul_storage)
        torch.sum(mul_storage, -2, keepdim=True, out=alpha)
        alpha.add_(eps)
        torch.div(residual_inner_prod, alpha, out=alpha)

        # Update result
        # result_{k} = result_{k-1} + alpha_{k} p_vec_{k-1}
        torch.addcmul(result, alpha, curr_conjugate_vec, out=result)

        # Update residual
        # residual_{k} = residual_{k-1} - alpha_{k} mat p_vec_{k-1}
        torch.addcmul(residual, -1, alpha, mvms, out=residual)

        # If residual are sufficiently small, then exit loop
        # Alternatively, exit if this is our last iteration
        torch.norm(residual, 2, dim=-2, out=residual_norm)
        if (residual_norm < tolerance).all() and not (n_tridiag
                                                      and k < n_tridiag_iter):
            break

        # Update precond_residual
        # precon_residual{k} = M^-1 residual_{k}
        precond_residual = preconditioner(residual)

        # beta_{k} = (precon_residual{k}^T r_vec_{k}) / (precon_residual{k-1}^T r_vec_{k-1})
        residual_inner_prod.add_(eps)
        torch.reciprocal(residual_inner_prod, out=beta)
        torch.mul(residual, precond_residual, out=mul_storage)
        torch.sum(mul_storage, -2, keepdim=True, out=residual_inner_prod)
        beta.mul_(residual_inner_prod)

        # Update curr_conjugate_vec
        # curr_conjugate_vec_{k} = precon_residual{k} + beta_{k} curr_conjugate_vec_{k-1}
        curr_conjugate_vec.mul_(beta).add_(precond_residual)

        # Update tridiagonal matrices, if applicable
        if n_tridiag and k < n_tridiag_iter and update_tridiag:
            alpha_tridiag = alpha.squeeze_(-2).narrow(-1, 0, n_tridiag)
            beta_tridiag = beta.squeeze_(-2).narrow(-1, 0, n_tridiag)
            torch.reciprocal(alpha_tridiag, out=alpha_reciprocal)

            if k == 0:
                t_mat[k, k].copy_(alpha_reciprocal)
            else:
                torch.addcmul(alpha_reciprocal,
                              prev_beta,
                              prev_alpha_reciprocal,
                              out=t_mat[k, k])
                torch.mul(prev_beta.sqrt_(),
                          prev_alpha_reciprocal,
                          out=t_mat[k, k - 1])
                t_mat[k - 1, k].copy_(t_mat[k, k - 1])

                if t_mat[k - 1, k].max() < 1e-6:
                    update_tridiag = False

            last_tridiag_iter = k

            prev_alpha_reciprocal.copy_(alpha_reciprocal)
            prev_beta.copy_(beta_tridiag)

    if is_vector:
        result = result.squeeze(-1)

    if n_tridiag:
        t_mat = t_mat[:last_tridiag_iter + 1, :last_tridiag_iter + 1]
        return result, t_mat.permute(-1, *range(2, 2 + len(batch_shape)), 0,
                                     1).contiguous()
    else:
        return result
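
Because linear_cg above pulls its defaults from the host library's settings module, here is a self-contained toy version of the same conjugate-gradient idea (no preconditioner, single right-hand side) showing what a matmul_closure is. All names and sizes below are made up for illustration.

import torch

torch.manual_seed(0)
A = torch.randn(50, 50)
A = A @ A.T + 50 * torch.eye(50)     # symmetric positive definite
rhs = torch.randn(50, 1)
matmul_closure = A.matmul            # left-multiplication by the lhs matrix

x = torch.zeros_like(rhs)
r = rhs - matmul_closure(x)          # residual_0 = b - A x_0
p = r.clone()                        # first conjugate direction
rs_old = (r * r).sum()
for _ in range(100):
    Ap = matmul_closure(p)
    alpha = rs_old / (p * Ap).sum()
    x = x + alpha * p
    r = r - alpha * Ap
    rs_new = (r * r).sum()
    if rs_new.sqrt() < 1e-6:         # stop when the residual norm is small
        break
    p = r + (rs_new / rs_old) * p
    rs_old = rs_new

print(torch.allclose(x, torch.linalg.solve(A, rhs), atol=1e-4))  # expected: True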
Example #14
    def get_mask(self, x, dropout):
        mask = torch.empty_like(x).bernoulli_(1 - dropout)
        mask = mask / (1 - dropout)

        return mask
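
A quick check (illustrative) of the inverted-dropout scaling used in get_mask above: dividing the Bernoulli mask by (1 - dropout) keeps the expected value of the masked activations unchanged.

import torch

torch.manual_seed(0)
x = torch.ones(100000)
dropout = 0.3
mask = torch.empty_like(x).bernoulli_(1 - dropout) / (1 - dropout)
print((x * mask).mean())  # close to 1.0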
Example #15
 def randvec(self, x, norm=1):
     u = torch.randn(x.shape, out=torch.empty_like(x))
     u = self.proju(x, u, inplace=True)  # "transport" ``u`` to ``x``
     u.div_(u.norm(dim=-1, keepdim=True)).mul_(norm)  # normalize
     return u
Example #16
def ccrop_batch(imgs_tensor):
    ccropped_imgs = torch.empty_like(imgs_tensor)
    for i, img_ten in enumerate(imgs_tensor):
        ccropped_imgs[i] = ccrop(img_ten)

    return ccropped_imgs
Example #17
 def randvec(self, x, norm):
     # vector distributed uniformly on the sphere of radius ``norm`` around x
     u = torch.randn(x.shape, out=torch.empty_like(x))
     u.div_(u.norm(dim=self.dims, keepdim=True)).mul_(norm)
     return u
Example #18
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()

                if grad.is_sparse:
                    raise RuntimeError(
                        'Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:
                    # first time running: init the state dictionary with our desired entries
                    # if self.first_run_check==0:
                    # self.first_run_check=1
                    #print("Initializing slow buffer...should not see this at load from saved model!")
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)

                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # GC operation for Conv layers and FC layers
                # if grad.dim() > self.gc_gradient_threshold:
                #    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
                if self.gc_loc:
                    grad = centralized_gradient(grad,
                                                use_gc=self.use_gc,
                                                gc_conv_only=self.gc_conv_only,
                                                dim=group['gc_dim'])

                state['step'] += 1

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # compute mean moving avg
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2**state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * \
                        state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) *
                            (N_sma - 2) / N_sma * N_sma_max /
                            (N_sma_max - 2)) / (1 - beta1**state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1**state['step'])
                    buffered[2] = step_size

                # if group['weight_decay'] != 0:
                #    p_data_fp32.add_(-group['weight_decay']
                #                     * group['lr'], p_data_fp32)

                # apply lr
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    G_grad = exp_avg / denom
                else:
                    G_grad = exp_avg

                if group['weight_decay'] != 0:
                    G_grad.add_(p_data_fp32, alpha=group['weight_decay'])
                # GC operation
                if self.gc_loc == False:
                    G_grad = centralized_gradient(
                        G_grad,
                        use_gc=self.use_gc,
                        gc_conv_only=self.gc_conv_only,
                        dim=group['gc_dim'])

                p_data_fp32.add_(G_grad, alpha=-step_size * group['lr'])
                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    # get access to slow param tensor
                    slow_p = state['slow_buffer']
                    # (fast weights - slow weights) * alpha
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)
                    # copy interpolated weights to RAdam param tensor
                    p.data.copy_(slow_p)

        return loss
Example #19
def highest_contrast(model,
                     test_loader,
                     num_concepts=5,
                     num_prototypes=9,
                     save_path=None):
    """Creates concept representation via highest contrast.

    The concepts are represented by the data samples that are most specific to a concept
    (the samples that yield the highest activation for a concept while at the same time
    not activating the other concepts).

    Parameters
    ----------
    model: torch.nn.Module
        The trained model with all its parameters.
    test_loader: DataLoader object
        Data loader that iterates over the test set.
    num_concepts: int
        Number of concepts of the model.
    num_prototypes: int
        Number of prototypical examples that should be displayed for each concept.
    save_path: str
        Path to the location where the bar plot should be saved.
    """
    model.eval()
    activations = []
    for x, _ in test_loader:
        x = x.float().to(
            "cuda:0" if next(model.parameters()).is_cuda else "cpu")
        with torch.no_grad():
            _, (concepts, _), _ = model(x)
            activations.append(concepts.squeeze())
    activations = torch.cat(activations)

    contrast_scores = torch.empty_like(activations)
    for c in range(num_concepts - 1):
        contrast_scores[:, c] = activations[:, c] - (
            activations[:, :c].sum(dim=1) + activations[:, c + 1:].sum(dim=1))
    contrast_scores[:, num_concepts - 1] = (
        activations[:, num_concepts - 1] -
        activations[:, :num_concepts - 1].sum(dim=1))

    _, top_test_idx = torch.topk(contrast_scores, num_prototypes, 0)

    top_examples = [
        test_loader.dataset.data[top_test_idx[:, concept]]
        for concept in range(num_concepts)
    ]
    # flatten list and ensure correct image shape
    top_examples = [
        img.unsqueeze(0) if len(img.size()) == 2 else img
        for sublist in top_examples for img in sublist
    ]

    plt.rcdefaults()
    fig, ax = plt.subplots()
    concept_names = ['Concept {}'.format(i + 1) for i in range(num_concepts)]

    start = 0.0
    end = num_concepts * x.size(-1)
    stepsize = abs(end - start) / num_concepts
    ax.yaxis.set_ticks(
        np.arange(start + 0.5 * stepsize, end - 0.49 * stepsize, stepsize))
    ax.set_yticklabels(concept_names)
    plt.xticks([])
    ax.set_xlabel('{} data examples with highest contrast per concept'.format(
        num_prototypes))
    ax.set_title('Concept Prototypes: ')
    save_or_show(make_grid(top_examples, nrow=num_prototypes, pad_value=1),
                 save_path)
    plt.rcdefaults()
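
For reference, the per-concept contrast score computed in the loop above can also be written as a single vectorized expression (illustrative check with random activations): each concept's activation minus the summed activations of all other concepts.

import torch

activations = torch.randn(10, 5)                                   # (samples, concepts)
contrast = 2 * activations - activations.sum(dim=1, keepdim=True)  # a_c - (total - a_c)
print(contrast.shape)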
Example #20
def extract_features(model,
                     data_loader,
                     dataset,
                     print_freq=10,
                     vlad=True,
                     pca=None,
                     gpu=None,
                     sync_gather=False):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    features = []

    if (pca is not None):
        pca.load()

    end = time.time()
    with torch.no_grad():
        for i, (imgs, fnames, _, _, _) in enumerate(data_loader):
            data_time.update(time.time() - end)

            outputs = extract_cnn_feature(model, imgs, vlad, gpu=gpu)
            if (pca is not None):
                outputs = pca.infer(outputs)
            outputs = outputs.data.cpu()

            features.append(outputs)

            batch_time.update(time.time() - end)
            end = time.time()

            if ((i + 1) % print_freq == 0 and rank == 0):
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                      batch_time.val,
                                                      batch_time.avg,
                                                      data_time.val,
                                                      data_time.avg))

    if (pca is not None):
        del pca

    if (sync_gather):
        # all gather features in parallel
        # cost more GPU memory but less time
        features = torch.cat(features).cuda(gpu)
        all_features = [torch.empty_like(features) for _ in range(world_size)]
        dist.all_gather(all_features, features)
        del features
        all_features = torch.cat(all_features).cpu()[:len(dataset)]
        features_dict = OrderedDict()
        for fname, output in zip(dataset, all_features):
            features_dict[fname[0]] = output
        del all_features
    else:
        # broadcast features in sequence
        # cost more time but less GPU memory
        bc_features = torch.cat(features).cuda(gpu)
        features_dict = OrderedDict()
        for k in range(world_size):
            bc_features.data.copy_(torch.cat(features))
            if (rank == 0):
                print("gathering features from rank no.{}".format(k))
            dist.broadcast(bc_features, k)
            l = bc_features.cpu().size(0)
            for fname, output in zip(dataset[k * l:(k + 1) * l],
                                     bc_features.cpu()):
                features_dict[fname[0]] = output
        del bc_features, features

    return features_dict
Example #21
    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
        thresh: Union[float, torch.Tensor] = -52.0,
        rest: Union[float, torch.Tensor] = -65.0,
        reset: Union[float, torch.Tensor] = -65.0,
        refrac: Union[int, torch.Tensor] = 5,
        tc_decay: Union[float, torch.Tensor] = 100.0,
        tc_i_decay: Union[float, torch.Tensor] = 2.0,
        lbound: float = None,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Instantiates a layer of synaptic input current-based LIF neurons.
        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        :param thresh: Spike threshold voltage.
        :param rest: Resting membrane voltage.
        :param reset: Post-spike reset voltage.
        :param refrac: Refractory (non-firing) period of the neuron.
        :param tc_decay: Time constant of neuron voltage decay.
        :param tc_i_decay: Time constant of synaptic input current decay.
        :param lbound: Lower bound of the voltage.
        """
        super().__init__(
            n=n,
            shape=shape,
            traces=traces,
            traces_additive=traces_additive,
            tc_trace=tc_trace,
            trace_scale=trace_scale,
            sum_input=sum_input,
        )

        self.register_buffer("rest", torch.tensor(rest))  # Rest voltage.
        self.register_buffer("reset",
                             torch.tensor(reset))  # Post-spike reset voltage.
        self.register_buffer("thresh",
                             torch.tensor(thresh))  # Spike threshold voltage.
        self.register_buffer(
            "refrac", torch.tensor(refrac))  # Post-spike refractory period.
        self.register_buffer(
            "tc_decay",
            torch.tensor(tc_decay))  # Time constant of neuron voltage decay.
        self.register_buffer("decay", torch.empty_like(
            self.tc_decay))  # Set in compute_decays.
        self.register_buffer("tc_i_decay", torch.tensor(
            tc_i_decay))  # Time constant of synaptic input current decay.
        self.register_buffer("i_decay", torch.empty_like(
            self.tc_i_decay))  # Set in compute_decays.

        self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
        self.register_buffer("i",
                             torch.FloatTensor())  # Synaptic input currents.
        self.register_buffer(
            "refrac_count", torch.FloatTensor())  # Refractory period counters.

        self.lbound = lbound  # Lower bound of voltage.
Example #22
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        Apply norm non-linearities to the input feature map
        
        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map
            
        """

        assert input.type == self.in_type

        input = input.tensor

        # scalar multipliers needed to turn the old norms into the newly computed ones
        multipliers = torch.empty_like(input)

        b, c, h, w = input.shape

        next_bias = 0

        if self.log_bias is not None:
            # build the bias
            # biases = torch.nn.functional.elu(self.log_bias)
            biases = torch.exp(self.log_bias)
            # biases = torch.nn.functional.elu(self.log_bias) + 1
        else:
            biases = None

        # iterate through all field sizes
        for s in self._order:

            # retrieve the corresponding fiber indices
            indices = getattr(self, f"indices_{s}")

            if self._contiguous[s]:
                # if the fields were contiguous, we can use slicing
                # retrieve the fields
                fm = input[:, indices[0]:indices[1], :, :]
            else:
                # otherwise we have to use indexing
                # retrieve the fields
                fm = input[:, indices, :, :]

            # compute the norm of each field
            norms = fm.view(b, -1, s, h, w).norm(dim=2, keepdim=True)

            # compute the new norms
            if biases is not None:
                # retrieve the bias elements corresponding to the current fields
                bias = biases[:, next_bias:next_bias + self._nfields[s],
                              ...].view(1, -1, 1, 1, 1)
                new_norms = self._function(norms - bias)
            else:
                new_norms = self._function(norms)

            # compute the scalar multipliers needed to turn the old norms into the newly computed ones
            # m = torch.zeros_like(new_norms)
            # in order to avoid division by 0
            # mask = norms > 0.
            # m[mask] = new_norms[mask] / norms[mask]

            m = new_norms / torch.max(norms, self.eps)
            m[norms <= self.eps] = 0.

            if self._contiguous[s]:
                # expand the multipliers tensor to all channels for each field
                multipliers[:, indices[0]:indices[1], :, :] = m.expand(
                    b, -1, s, h, w).reshape(b, -1, h, w)

            else:
                # expand the multipliers tensor to all channels for each field
                multipliers[:,
                            indices, :, :] = m.expand(b, -1, s, h,
                                                      w).reshape(b, -1, h, w)

            # shift the position on the bias tensor
            next_bias += self._nfields[s]

        # multiply the input by the multipliers computed and wrap the result in a GeometricTensor
        return GeometricTensor(input * multipliers, self.out_type)
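The same norm-based gating can be sketched without the equivariance machinery. A minimal standalone version for a single field size s, assuming ReLU(norm - bias) as the norm non-linearity and a small eps (illustrative only, not the class above):

import torch

def norm_relu(x: torch.Tensor, s: int, bias: float = 0.5, eps: float = 1e-6) -> torch.Tensor:
    # x: (b, c, h, w), with c divisible by the field size s.
    b, c, h, w = x.shape
    fields = x.view(b, -1, s, h, w)
    norms = fields.norm(dim=2, keepdim=True)     # one norm per field
    new_norms = torch.relu(norms - bias)         # non-linearity acts on the norm
    m = new_norms / norms.clamp(min=eps)         # multiplier turning old norms into new ones
    m[norms <= eps] = 0.0                        # avoid division blow-up at zero norm
    return (fields * m).reshape(b, c, h, w)

y = norm_relu(torch.randn(2, 6, 4, 4), s=3)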
Example #23
0
def hflip_batch(imgs_tensor):
    hfliped_imgs = torch.empty_like(imgs_tensor)
    for i, img_ten in enumerate(imgs_tensor):
        hfliped_imgs[i] = hflip(img_ten)

    return hfliped_imgs
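A quick usage sketch, assuming hflip is torchvision.transforms.functional.hflip applied to CHW image tensors:

import torch
from torchvision.transforms.functional import hflip

imgs = torch.rand(8, 3, 32, 32)      # batch of CHW images
flipped = hflip_batch(imgs)          # same shape, mirrored along the width axis
assert torch.equal(flipped, imgs.flip(-1))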
Example #24
0
def train(local_rank,
          args,
          cache_state,
          data_files,
          end_dataloader,
          end_train,
          dist_training=True):

    setuplogger()
    try:
        if dist_training:
            init_process(local_rank, args.world_size)
        device = get_device()
        barrier = get_barrier(dist_training)

        news_info, news_combined = get_news_feature(args, mode='train')
        with only_on_main_process(local_rank, barrier) as need:
            if need:
                data_paths = []
                data_dirs = os.path.join(args.root_data_dir, 'train/')
                data_paths.extend(get_files(data_dirs, args.filename_pat))
                data_paths.sort()

        model = MLNR(args)
        if 'speedymind_ckpts' in args.pretrained_model_path:
            ckpt = torch.load(
                os.path.join(args.pretrained_model_path, 'pytorch_model.bin'))
            model.load_state_dict(ckpt['model_state_dict'])

        model = model.to(device)
        rest_param = filter(
            lambda x: id(x) not in list(
                map(id, model.news_encoder.unicoder.parameters())),
            model.parameters())
        optimizer = optim.Adam([
            {
                'params': model.news_encoder.unicoder.parameters(),
                'lr': args.pretrain_lr  #lr_schedule(args.pretrain_lr, 1, args)
            },
            {
                'params': rest_param,
                'lr': args.lr  #lr_schedule(args.lr, 1, args)
            }
        ])
        #

        if dist_training:
            ddp_model = DDP(model,
                            device_ids=[local_rank],
                            output_device=local_rank,
                            find_unused_parameters=True)
        else:
            ddp_model = model

        logging.info('Training...')
        start_time = time.time()
        test_time = 0.0
        global_step = 0
        best_count = 0
        optimizer.zero_grad()

        loss = 0.0
        best_auc = 0.0
        accuracy = 0.0
        hit_num = 0
        all_num = 1
        encode_num = 0
        cache = np.zeros((len(news_combined), args.news_dim))
        for ep in range(args.epochs):
            with only_on_main_process(local_rank, barrier) as need:
                if need:
                    while len(data_files) > 0:
                        data_files.pop()
                    data_files.extend(data_paths)
                    random.shuffle(data_files)
            barrier()

            dataloader = DataLoaderTrainForSpeedyRec(
                args=args,
                data_files=data_files,
                cache_state=cache_state,
                end=end_dataloader,
                local_rank=local_rank,
                world_size=args.world_size,
                news_features=news_combined,
                news_index=news_info.news_index,
                enable_prefetch=args.enable_prefetch,
                enable_prefetch_stream=args.enable_prefetch_stream,
                global_step=global_step,
                add_pad_news=True)

            ddp_model.train()
            pad_doc = torch.zeros(1, args.news_dim, device=device)

            for cnt, batch in tqdm(enumerate(dataloader)):
                with torch.autograd.set_detect_anomaly(True):
                    address_cache, update_cache, start_inx, end_inx, batch = batch
                    global_step += 1

                    if args.enable_gpu:
                        input_ids, hist_sequence, hist_sequence_mask, candidate_inx, label_batch = (
                            x.cuda(device=device, non_blocking=True)
                            if x is not None else x for x in batch[:5])
                    else:
                        input_ids, hist_sequence, hist_sequence_mask, candidate_inx, label_batch = batch[:5]

                    encode_num += input_ids.size(0)

                    # Get news vecs from cache.
                    if address_cache is not None:
                        # cache_vec = [cache[inx] for inx in address_cache]
                        cache_vec = cache[address_cache]
                        cache_vec = torch.FloatTensor(cache_vec).cuda(
                            device=device, non_blocking=True)

                        # atime += time.time() - temp_stime
                        hit_num += cache_vec.size(0)
                        all_num += cache_vec.size(0)

                    else:
                        cache_vec = None
                        hit_num += 0

                    if cache_vec is not None:
                        cache_vec = torch.cat([pad_doc, cache_vec], 0)
                    else:
                        cache_vec = pad_doc

                    if input_ids.size(0) > 0:
                        if dist_training:
                            encode_vecs = ddp_model.module.news_encoder(
                                input_ids)
                        else:
                            encode_vecs = ddp_model.news_encoder(input_ids)
                    else:
                        encode_vecs = None

                    all_tensors = [
                        torch.empty_like(encode_vecs)
                        for _ in range(args.world_size)
                    ]
                    dist.all_gather(all_tensors, encode_vecs)
                    all_tensors[local_rank] = encode_vecs
                    all_encode_vecs = torch.cat(all_tensors, dim=0)
                    news_vecs = torch.cat([cache_vec, all_encode_vecs], 0)

                    all_num += all_encode_vecs.size(0)
                    bz_loss, y_hat = ddp_model(news_vecs, hist_sequence,
                                               hist_sequence_mask,
                                               candidate_inx, label_batch)

                    loss += bz_loss.item()
                    bz_loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                    accuracy += acc(label_batch, y_hat)

                    # update the cache
                    if args.max_step_in_cache > 0 and encode_vecs is not None:
                        update_vecs = all_encode_vecs.detach().cpu().numpy(
                        )[:len(update_cache)]
                        cache[update_cache] = update_vecs

                    optimizer.param_groups[0]['lr'] = lr_schedule(
                        args.pretrain_lr, global_step, args)
                    optimizer.param_groups[1]['lr'] = lr_schedule(
                        args.lr, global_step, args)

                    barrier()

                if global_step % args.log_steps == 0:
                    logging.info(
                        '[{}] cost_time:{} step:{}, train_loss: {:.5f}, acc:{:.5f}, hit:{}, encode_num:{}, lr:{:.8f}, pretrain_lr:{:.8f}'
                        .format(local_rank,
                                time.time() - start_time - test_time,
                                global_step, loss / args.log_steps,
                                accuracy / args.log_steps, hit_num / all_num,
                                encode_num, optimizer.param_groups[1]['lr'],
                                optimizer.param_groups[0]['lr']))
                    loss = 0.0
                    accuracy = 0.0

                if global_step % args.test_steps == 0 and local_rank == 0:
                    stest_time = time.time()
                    auc = test(model, args, device, news_info.category_dict,
                               news_info.subcategory_dict)
                    ddp_model.train()
                    logging.info('step:{}, auc:{}'.format(global_step, auc))
                    test_time = test_time + time.time() - stest_time

                # save model minibatch
                if local_rank == 0 and global_step % args.save_steps == 0:
                    ckpt_path = os.path.join(
                        args.model_dir,
                        f'{args.savename}-epoch-{ep + 1}-{global_step}.pt')
                    torch.save(
                        {
                            'model_state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'category_dict': news_info.category_dict,
                            'subcategory_dict': news_info.subcategory_dict,
                        }, ckpt_path)
                    logging.info(f"Model saved to {ckpt_path}")

            logging.info('epoch:{}, time:{}, encode_num:{}'.format(
                ep + 1,
                time.time() - start_time - test_time, encode_num))
            # save model after an epoch
            if local_rank == 0:
                ckpt_path = os.path.join(
                    args.model_dir,
                    '{}-epoch-{}.pt'.format(args.savename, ep + 1))
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'category_dict': news_info.category_dict,
                        'subcategory_dict': news_info.subcategory_dict,
                    }, ckpt_path)
                logging.info(f"Model saved to {ckpt_path}")

                auc = test(model, args, device, news_info.category_dict,
                           news_info.subcategory_dict)
                ddp_model.train()

                if auc > best_auc:
                    best_auc = auc
                else:
                    best_count += 1
                    if best_count >= 3:
                        logging.info("best_auc:{}, best_ep:{}".format(
                            best_auc, ep - 3))
                        end_train.value = True
            barrier()
            if end_train.value:
                break

        if dist_training:
            cleanup_process()

    except Exception:
        error_type, error_value, error_trace = sys.exc_info()
        traceback.print_tb(error_trace)
        logging.info(error_value)
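The cache handling inside the loop above boils down to: look up already-encoded news vectors by index, encode only the misses, and write the fresh vectors back. A simplified single-process sketch of that pattern (encode is a hypothetical stand-in for the news encoder):

import numpy as np
import torch

def encode_with_cache(ids, cache, cached_idx, fresh_idx, encode):
    cached_vecs = torch.from_numpy(cache[cached_idx]).float()  # cache hits
    fresh_vecs = encode(ids[fresh_idx])                        # encode only the misses
    cache[fresh_idx] = fresh_vecs.detach().cpu().numpy()       # write back for later steps
    return torch.cat([cached_vecs, fresh_vecs], dim=0)

cache = np.zeros((100, 16), dtype=np.float32)
encode = lambda x: torch.randn(len(x), 16)  # stand-in encoder
vecs = encode_with_cache(torch.arange(10), cache, [0, 1, 2], [3, 4], encode)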
Example #25
0
    def step(self, closure=None):
        loss = None
        # Note: the closure call below is commented out because some callers pass
        # the loss back as a float rather than a callable closure.
        # Uncomment it if you need to evaluate an actual closure:

        # if closure is not None:
        #     loss = closure()

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue

                # Perform stepweight decay
                p.mul_(1 - group['lr'] * group['weight_decay'])

                grad = p.grad

                if grad.is_sparse:
                    raise RuntimeError(
                        'Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data

                state = self.state[p]  # get state dict for this param

                # First time running this param: init the state dictionary with our desired entries.
                if len(state) == 0:
                    # if self.first_run_check==0:
                    # self.first_run_check=1
                    #print("Initializing slow buffer...should not see this at load from saved model!")
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)

                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # compute mean moving avg
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2**state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * \
                        state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) *
                            (N_sma - 2) / N_sma * N_sma_max /
                            (N_sma_max - 2)) / (1 - beta1**state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1**state['step'])
                    buffered[2] = step_size

                # apply lr
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    G_grad = exp_avg / denom
                else:
                    G_grad = exp_avg

                p_data_fp32.add_(G_grad, alpha=-step_size * group['lr'])
                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    # get access to slow param tensor
                    slow_p = state['slow_buffer']
                    # (fast weights - slow weights) * alpha
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)
                    # copy interpolated weights to RAdam param tensor
                    p.data.copy_(slow_p)

        return loss
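For reference, the N_sma / step_size block above is RAdam's variance-rectification term. A standalone sketch of the same arithmetic (assuming the usual (beta1, beta2) convention and a threshold of 5):

import math

def radam_step_size(step: int, beta1: float, beta2: float, threshold: float = 5.0) -> float:
    beta2_t = beta2 ** step
    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
    n_sma = n_sma_max - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if n_sma > threshold:  # variance is tractable: apply the rectification factor
        rect = math.sqrt((1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4)
                         * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2))
        return rect / (1 - beta1 ** step)
    return 1.0 / (1 - beta1 ** step)   # fall back to an unrectified step

print(radam_step_size(10, 0.95, 0.999), radam_step_size(1000, 0.95, 0.999))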
Example #26
0
    def step(self, closure=None):
        loss = None

        if closure is not None:
            loss = closure()

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError(
                        'Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                # On the first run initialize the dictionary for each weight group
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)
                # @TODO: couldn't this cast just run unconditionally after the
                # init above, making the torch.zeros_like calls redundant?
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # compute mean moving avg
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1

                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2**state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 -
                                                                       beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshold:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) *
                            (N_sma - 2) / N_sma * N_sma_max /
                            (N_sma_max - 2)) / (1 - beta1**state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1**state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32,
                                     alpha=-group['weight_decay'] * group['lr'])

                if N_sma > self.N_sma_threshold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg,
                                         denom,
                                         value=-step_size * group['lr'])
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    slow_p = state['slow_buffer']
                    # Find the interpolated weight between the slower buffer (the weight `k` steps ago)
                    # and the current weight, set that as the state for RAdam
                    slow_p.add_(p.data - slow_p, alpha=self.alpha)
                    p.data.copy_(slow_p)

        return loss
Example #27
0
def vtln_warp_freq(vtln_low_cutoff: float, vtln_high_cutoff: float,
                   low_freq: float, high_freq: float, vtln_warp_factor: float,
                   freq: Tensor) -> Tensor:
    r"""This computes a VTLN warping function that is not the same as HTK's one,
    but has similar inputs (this function has the advantage of never producing
    empty bins).

    This function computes a warp function F(freq), defined between low_freq
    and high_freq inclusive, with the following properties:
        F(low_freq) == low_freq
        F(high_freq) == high_freq
    The function is continuous and piecewise linear with two inflection
        points.
    The lower inflection point (measured in terms of the unwarped
        frequency) is at frequency l, determined as described below.
    The higher inflection point is at a frequency h, determined as
        described below.
    If l <= f <= h, then F(f) = f/vtln_warp_factor.
    If the higher inflection point (measured in terms of the unwarped
        frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
        Since (by the last point) F(h) == h/vtln_warp_factor, then
        max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
        h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
          = vtln_high_cutoff * min(1, vtln_warp_factor).
    If the lower inflection point (measured in terms of the unwarped
        frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
        This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
                            = vtln_low_cutoff * max(1, vtln_warp_factor)
    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        freq (Tensor): given frequency in Hz

    Returns:
        Tensor: Freq after vtln warp
    """
    assert vtln_low_cutoff > low_freq, 'be sure to set the vtln_low option higher than low_freq'
    assert vtln_high_cutoff < high_freq, 'be sure to set the vtln_high option lower than high_freq [or negative]'
    l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    scale = 1.0 / vtln_warp_factor
    Fl = scale * l  # F(l)
    Fh = scale * h  # F(h)
    assert l > low_freq and h < high_freq
    # slope of left part of the 3-piece linear function
    scale_left = (Fl - low_freq) / (l - low_freq)
    # [slope of center part is just "scale"]

    # slope of right part of the 3-piece linear function
    scale_right = (high_freq - Fh) / (high_freq - h)

    res = torch.empty_like(freq)

    outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(
        freq, high_freq)  # freq < low_freq || freq > high_freq
    before_l = torch.lt(freq, l)  # freq < l
    before_h = torch.lt(freq, h)  # freq < h
    after_h = torch.ge(freq, h)  # freq >= h

    # order of operations matter here (since there is overlapping frequency regions)
    res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
    res[before_h] = scale * freq[before_h]
    res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
    res[outside_low_high_freq] = freq[outside_low_high_freq]

    return res
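A quick usage sketch with made-up cutoff values (chosen only to satisfy the asserts above):

import torch

freqs = torch.linspace(0.0, 8000.0, steps=9)
warped = vtln_warp_freq(
    vtln_low_cutoff=100.0, vtln_high_cutoff=7600.0,
    low_freq=20.0, high_freq=7800.0,
    vtln_warp_factor=1.1, freq=freqs)
# Frequencies outside [20, 7800] Hz pass through unchanged; the central band is scaled by 1/1.1.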
Example #28
0
    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
        rest: Union[float, torch.Tensor] = -60.0,
        reset: Union[float, torch.Tensor] = -45.0,
        thresh: Union[float, torch.Tensor] = -40.0,
        tc_decay: Union[float, torch.Tensor] = 10.0,
        R: Union[float, torch.Tensor] = 32,
        tau_inc: Union[float, torch.Tensor] = 10.,
        tau_dec: Union[float, torch.Tensor] = 5.,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Instantiates a layer of Hao & Huang (2019) SL neurons.

        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        :param rest: Resting membrane voltage.
        :param reset: Post-spike reset voltage.
        :param thresh: Spike threshold voltage.
        :param tc_decay: Time constant of neuron voltage decay.
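        :param R: Membrane resistance.
        :param tau_inc: Time constant of synaptic current rise.
        :param tau_dec: Time constant of synaptic current decay.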
        """
        super().__init__(
            n=n,
            shape=shape,
            traces=traces,
            traces_additive=traces_additive,
            tc_trace=tc_trace,
            trace_scale=trace_scale,
            sum_input=sum_input,
        )

        self.register_buffer("rest", torch.tensor(rest))  # Rest voltage.
        self.register_buffer("reset",
                             torch.tensor(reset))  # Post-spike reset voltage.
        self.register_buffer("thresh",
                             torch.tensor(thresh))  # Spike threshold voltage.
        self.register_buffer(
            "tc_decay",
            torch.tensor(tc_decay))  # Time constant of neuron voltage decay.
        self.register_buffer("decay", torch.empty_like(
            self.tc_decay))  # Set in compute_decays.
        self.register_buffer("I", torch.FloatTensor())
        self.register_buffer("X", torch.FloatTensor())
        self.register_buffer("tau_inc", torch.tensor(tau_inc))
        self.register_buffer("tau_dec", torch.tensor(tau_dec))
        self.register_buffer("I_decay", torch.empty_like(self.tau_dec))
        self.register_buffer("X_decay", torch.empty_like(self.tau_dec))
        self.register_buffer("C", torch.empty_like(self.tau_dec))
        self.register_buffer("R", torch.tensor(R))
        self.register_buffer("v", torch.FloatTensor())
Example #29
0
def integer_dropout(t: torch.Tensor, fill_value: int,
                    p: float) -> torch.Tensor:
    mask = torch.empty_like(t, dtype=torch.bool).bernoulli_(p)
    return t.masked_fill(mask, fill_value)
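A small usage sketch, e.g. randomly replacing token ids with a padding id (note that p is the probability of dropping an element, not of keeping it):

import torch

tokens = torch.randint(1, 1000, (4, 8))
dropped = integer_dropout(tokens, fill_value=0, p=0.1)  # roughly 10% of ids become 0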
Example #30
0
    def randvec(self, x, norm):
        u = torch.randn(x.shape, out=torch.empty_like(x))
        u = self.proju(x, u)  # "transport" ``u`` to ``x``
        u.div_(u.norm(dim=(-2, -1), keepdim=True)).mul_(norm)
        return u
Example #31
0
    def test_empty_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.empty_like(x), x)
Example #32
0
    def test_empty_like_opset7(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.empty_like(x), x, opset_version=7)
Example #34
0
    def forward(self,
                input,
                input_mask=None,
                attention_mask=None,
                head_mask=None,
                layer_past=None,
                get_key_value=False,
                get_present=False,
                encoder_output=None,
                enc_dec_attn_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                use_cache=False,
                output_attentions=False):
        get_present = (get_present or get_key_value or use_cache)
        input_mask = input_mask if attention_mask is None else attention_mask
        input_type = input.dtype

        if (self.config.fp16 or self.config.q_int8) \
            and input.dtype == torch.float:
            input = input.half()

        with torch.no_grad():
            attention_output = self.attention(input, input_mask, head_mask,
                                              layer_past, get_present,
                                              encoder_hidden_states,
                                              encoder_attention_mask,
                                              output_attentions, self.norm_w,
                                              self.norm_b)

            if get_present:
                attention_output, p_key, p_value = attention_output[0:3]
                presents = (p_key, p_value)
            elif output_attentions:
                attention_output, _, _, context_output = attention_output[0:4]
            else:
                attention_output = attention_output[0]

            residual_add = attention_output + self.attention.attn_ob
            attention_output = self.ds_layernorm(residual_add, self.attn_nw,
                                                 self.attn_nb,
                                                 self.config.epsilon)

            if self.config.mlp_type == 'residual':
                res_mlp_out = self.res_mlp(attention_output, async_op=True)
                res_coef_out = self.res_coef_func(attention_output,
                                                  async_op=True)

            if self.expert_mp_group is not None:
                tensor_list = [
                    torch.empty_like(attention_output) for _ in range(
                        dist.get_world_size(group=self.expert_mp_group))
                ]
                tensor_list[dist.get_rank(
                    group=self.expert_mp_group)] = attention_output
                dist.all_gather(tensor_list,
                                attention_output,
                                group=self.expert_mp_group)
                attention_output = torch.cat(tensor_list).contiguous()

            ############## MoE Gating + Experts ###############
            dispatched_attention, combined_weights = self.moe_gate_einsum(
                attention_output)
            dispatched_input = self._alltoall(dispatched_attention)
            expert_outputs = self.expert_exec(dispatched_input)
            expert_output = self._alltoall(expert_outputs)
            output = self.scale_expert_output(attention_output, expert_output,
                                              combined_weights)
            ################################################

            if self.expert_mp_group is not None:
                output = output.split(
                    output.shape[0] //
                    dist.get_world_size(group=self.expert_mp_group),
                    dim=0)[dist.get_rank(group=self.expert_mp_group)]

            if self.config.mlp_type == 'residual':
                inference_cuda_module.moe_res_matmul(res_mlp_out, res_coef_out,
                                                     output)

            output = self.bias_residual_func(output, residual_add,
                                             torch.empty(1))

            if not self.config.pre_layer_norm:
                output = self.ds_layernorm(output, self.norm_w, self.norm_b,
                                           self.config.epsilon)

            if input_type != output.dtype:
                output = output.to(input_type)

        if get_present:
            output = (output, presents)

        if self.config.return_tuple:
            return output if type(output) is tuple else (output, )
        else:
            return output
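The expert_mp_group handling above is a gather-then-split pattern: each rank contributes its shard via all_gather, the shared computation runs on the full batch, and every rank keeps only its own slice afterwards. A minimal sketch of that pattern (assuming an already-initialised process group and equal shard sizes):

import torch
import torch.distributed as dist

def gather_then_split(local_x: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    shards = [torch.empty_like(local_x) for _ in range(world_size)]
    dist.all_gather(shards, local_x, group=group)
    full = torch.cat(shards, dim=0)
    # ... run whatever needs the full batch on `full` ...
    return full.split(full.shape[0] // world_size, dim=0)[rank]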