Exemplo n.º 1
0
def load_checkpoint(path):
    """Load a checkpoint onto the CPU, fetching `gs://` paths to local storage first.

    For a `gs://` path, one process out of every 8 ranks (rank % 8 == 0)
    makes sure the file is present, preferring an existing copy on Google
    Drive, then downloading to Google Drive if its target directory exists,
    and otherwise downloading into `~/.cache`. All ranks then synchronise on
    a barrier and load the same file.

    Args:
        path: Filesystem path, or a `gs://` URL, of the checkpoint.

    Returns:
        Whatever `t.load` returns for the resolved file, mapped to the CPU.
    """
    print("Loading checkpoint...")
    restore = path
    is_remote = restore[:5] == 'gs://'
    if is_remote:
        gs_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[5:])
        gdrive_path = os.path.join("/content/gdrive/My Drive/samples/", gs_path[5:])
        print(f'local path: {local_path}')
        print(f'gdrive path: {gdrive_path}')
        # Only one rank per group of 8 touches storage; the others wait at the barrier.
        if dist.get_rank() % 8 == 0:
            if os.path.exists(gdrive_path):
                print("Using priors on Google Drive")
            elif os.path.exists(os.path.dirname(gdrive_path)):
                print("Downloading priors to Google Drive")
                download(gs_path, gdrive_path)
            else:
                print("Downloading from gce")
                # exist_ok avoids a check-then-create race between processes
                # that share the same cache directory.
                os.makedirs(os.path.dirname(local_path), exist_ok=True)
                if not os.path.exists(local_path):
                    download(gs_path, local_path)
    dist.barrier()
    if is_remote:
        # Resolve the location on EVERY rank, after the barrier. Previously only
        # the downloading rank updated `restore`, so the remaining ranks tried
        # to load the raw gs:// URL.
        restore = gdrive_path if os.path.exists(gdrive_path) else local_path
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("RS // Restored from {}".format(restore))
    return checkpoint
Exemplo n.º 2
0
def get_optimizer(model, hps):
    """Create the optimizer, LR scheduler and optional fp16 loss scaler.

    Args:
        model: Module whose parameters are optimised; its grads are zeroed
            before returning.
        hps: Hyperparameters; reads beta1, beta2, fp16_opt, lr, weight_decay,
            eps, fp16, fp16_loss_scale and fp16_scale_window.

    Returns:
        Tuple `(opt, shd, scalar)`; `scalar` is None unless `hps.fp16` is set.
    """
    # Optimizer: both variants take the same arguments.
    adam_kwargs = dict(lr=hps.lr,
                       weight_decay=hps.weight_decay,
                       betas=(hps.beta1, hps.beta2),
                       eps=hps.eps)
    if hps.fp16_opt:
        opt = FP16FusedAdam(model.parameters(), **adam_kwargs)
    else:
        opt = FusedAdam(model.parameters(), **adam_kwargs)

    # lr scheduler
    shd = get_lr_scheduler(opt, hps)

    # fp16 dynamic loss scaler
    scalar = None
    if hps.fp16:
        is_group_leader = dist.get_rank() % 8 == 0
        growth = 2**(1. / hps.fp16_scale_window)
        scalar = LossScalar(hps.fp16_loss_scale, scale_factor=growth)
        if is_group_leader:
            print(scalar.__dict__)

    zero_grad(model)
    return opt, shd, scalar
Exemplo n.º 3
0
    def __init__(self,
                 n_in,
                 n_depth,
                 m_conv=1.0,
                 dilation_growth_rate=1,
                 dilation_cycle=None,
                 zero_out=False,
                 res_scale=False,
                 reverse_dilation=False,
                 checkpoint_res=False):
        """Build a stack of `n_depth` `ResConv1DBlock` modules.

        Block `d` uses dilation `dilation_growth_rate ** d`, with `d` wrapped
        modulo `dilation_cycle` when a cycle is given. When `res_scale` is
        truthy each block gets `res_scale = 1 / sqrt(n_depth)`, otherwise 1.0.
        The blocks are stored as `self.blocks` (a ModuleList) when
        `checkpoint_res == 1`, else as `self.model` (a Sequential).
        """
        super().__init__()

        layers = []
        for depth in range(n_depth):
            # The dilation exponent restarts every `dilation_cycle` blocks.
            exponent = depth if dilation_cycle is None else depth % dilation_cycle
            layers.append(
                ResConv1DBlock(n_in,
                               int(m_conv * n_in),
                               dilation=dilation_growth_rate**exponent,
                               zero_out=zero_out,
                               res_scale=1.0 / math.sqrt(n_depth) if res_scale else 1.0))
        if reverse_dilation:
            layers.reverse()

        self.checkpoint_res = checkpoint_res
        if self.checkpoint_res == 1:
            if dist.get_rank() == 0:
                print("Checkpointing convs")
            self.blocks = nn.ModuleList(layers)
        else:
            self.model = nn.Sequential(*layers)
Exemplo n.º 4
0
def get_ddp(model, hps):
    """Wrap `model` in DistributedDataParallel on device index `rank % 8`.

    `hps.bucket` is forwarded as `bucket_cap_mb`; buffer broadcasting is off.
    """
    device_index = dist.get_rank() % 8
    return DistributedDataParallel(
        model,
        device_ids=[device_index],
        output_device=device_index,
        broadcast_buffers=False,
        bucket_cap_mb=hps.bucket,
    )
Exemplo n.º 5
0
    def init_dataset(self, hps):
        """Discover audio files under `hps.audio_files_dir` and index them.

        Durations are converted to sample counts with `self.sr` and handed to
        `self.filter` together with the file list. A `Labeller` is attached
        when `self.labels` is truthy.
        """
        # Load list of files and starts/durations
        audio_paths = librosa.util.find_files(f'{hps.audio_files_dir}', ['mp3', 'opus', 'm4a', 'aac', 'wav'])
        print_all(f"Found {len(audio_paths)} files. Getting durations")

        # Duration caching is enabled on one rank out of every 8, or always
        # when `dist` reports it is unavailable.
        if dist.is_available():
            use_cache = dist.get_rank() % 8 == 0
        else:
            use_cache = True

        # Could be approximate
        lengths = [get_duration_sec(audio_path, cache=use_cache) * self.sr for audio_path in audio_paths]
        self.filter(audio_paths, np.array(lengths))

        if self.labels:
            self.labeller = Labeller(hps.max_bow_genre_size, hps.n_tokens, self.sample_length, v3=hps.labels_v3)
Exemplo n.º 6
0
 def lr_lambda(step):
     """Return the learning-rate multiplier for `step`.

     Both schedules apply a linear warmup `min(1, step / hps.lr_warmup)`.
     With `hps.lr_use_linear_decay` the multiplier then falls linearly to 0
     over `hps.lr_decay` steps starting at `hps.lr_start_linear_decay`;
     otherwise it is multiplied by `hps.lr_gamma` once per `hps.lr_decay` steps.
     """
     warmup = min(1.0, step / hps.lr_warmup)
     if not hps.lr_use_linear_decay:
         return hps.lr_scale * (hps.lr_gamma ** (step // hps.lr_decay)) * warmup

     steps_into_decay = max(0.0, step - hps.lr_start_linear_decay)
     decay = max(0.0, 1.0 - steps_into_decay / hps.lr_decay)
     if decay == 0.0 and dist.get_rank() == 0:
         print("Reached end of training")
     return hps.lr_scale * warmup * decay
Exemplo n.º 7
0
def get_ema(model, hps):
    """Return an EMA tracker for `model`'s parameters, or None.

    An EMA is only built when both `hps.ema` and `hps.train` are truthy.
    The variant is chosen by `hps.cpu_ema`, then `hps.ema_fused`, falling
    back to plain `EMA`. The decay is `hps.mu` when truthy, otherwise it is
    derived from `hps.bs` and `hps.ngpus`.
    """
    mu = hps.mu or (1. - (hps.bs * hps.ngpus/8.)/1000)
    if not (hps.ema and hps.train):
        return None

    if hps.cpu_ema:
        if dist.get_rank() == 0:
            print("Using CPU EMA")
        return CPUEMA(model.parameters(), mu=mu, freq=hps.cpu_ema_freq)
    if hps.ema_fused:
        return FusedEMA(model.parameters(), mu=mu)
    return EMA(model.parameters(), mu=mu)
Exemplo n.º 8
0
def _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose):
    """Initialise the process group using rank and size taken from MPI.

    Exports the rendezvous and NCCL environment variables, selects CUDA
    device `mpi_rank % 8` when CUDA is available, and calls
    `dist.init_process_group` up to `n_attempts` times.

    Args:
        master_addr: Address exported as MASTER_ADDR.
        backend: Backend name passed to `dist.init_process_group`.
        port: Port exported as MASTER_PORT.
        n_attempts: Maximum number of initialisation attempts.
        verbose: When truthy, print the master address before connecting.

    Returns:
        Tuple `(mpi_rank, local_rank, device)`.

    Raises:
        RuntimeError: If every attempt fails; the last caught error is
            attached as the cause.
    """
    from mpi4py import MPI  # This must be imported in order to get errors from all ranks to show up

    mpi_rank = MPI.COMM_WORLD.Get_rank()
    mpi_size = MPI.COMM_WORLD.Get_size()

    os.environ["RANK"] = str(mpi_rank)
    os.environ["WORLD_SIZE"] = str(mpi_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)
    os.environ["NCCL_LL_THRESHOLD"] = "0"
    os.environ["NCCL_NSOCKS_PERTHREAD"] = "2"
    os.environ["NCCL_SOCKET_NTHREADS"] = "8"

    # Pin this rank to a specific GPU on the node
    local_rank = mpi_rank % 8
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    print(f'hmm {torch.cuda.is_available()}')

    if verbose:
        print(f"Connecting to master_addr: {master_addr}")

    # There is a race condition when initializing NCCL with a large number of ranks (e.g 500 ranks)
    # We guard against the failure and then retry
    last_error = None
    for attempt_idx in range(n_attempts):
        try:
            dist.init_process_group(backend=backend, init_method="env://")
            assert dist.get_rank() == mpi_rank

            use_cuda = torch.cuda.is_available()
            print(f'Using cuda {use_cuda}')
            if use_cuda:
                device = torch.device("cuda", local_rank)
                torch.cuda.set_device(local_rank)
            else:
                device = torch.device("cpu")

            return mpi_rank, local_rank, device
        except RuntimeError as e:
            last_error = e
            print(
                f"Caught error during NCCL init (attempt {attempt_idx} of {n_attempts}): {e}"
            )
            # No point sleeping when there is no further attempt to wait for.
            if attempt_idx + 1 < n_attempts:
                sleep(1 + (0.01 * mpi_rank))  # Sleep to avoid thundering herd

    # Keep the underlying failure attached so it is not lost to the caller.
    raise RuntimeError("Failed to initialize NCCL") from last_error
Exemplo n.º 9
0
def load_checkpoint(path):
    """Load a checkpoint onto the CPU, caching `gs://` paths under `~/.cache`.

    One rank out of every 8 downloads the file if it is not cached yet; all
    ranks meet at a barrier before loading.
    """
    ckpt_path = path
    if path[:5] == 'gs://':
        cached_path = os.path.join(os.path.expanduser("~/.cache"), path[5:])
        if dist.get_rank() % 8 == 0:
            print("Downloading from gce")
            cache_dir = os.path.dirname(cached_path)
            if not os.path.exists(cache_dir):
                os.makedirs(cache_dir)
            if not os.path.exists(cached_path):
                download(path, cached_path)
        ckpt_path = cached_path
    dist.barrier()
    checkpoint = t.load(ckpt_path, map_location=t.device('cpu'))
    print("Restored from {}".format(ckpt_path))
    return checkpoint
Exemplo n.º 10
0
def load_checkpoint(path):
    """Load a checkpoint onto the CPU, caching remote paths under `~/.cache`.

    A path starting with `REMOTE_PREFIX` is mirrored locally by one rank out
    of every 8; all ranks meet at a barrier, then load via `custom_load`.
    """
    ckpt_path = path
    if path.startswith(REMOTE_PREFIX):
        relative = path[len(REMOTE_PREFIX):]
        cached_path = os.path.join(os.path.expanduser("~/.cache"), relative)
        if dist.get_rank() % 8 == 0:
            print("Downloading from azure")
            cache_dir = os.path.dirname(cached_path)
            if not os.path.exists(cache_dir):
                os.makedirs(cache_dir)
            if not os.path.exists(cached_path):
                download(path, cached_path)
        ckpt_path = cached_path
    dist.barrier()
    checkpoint = custom_load(ckpt_path, map_location=t.device('cpu'))
    print("Restored from {}".format(ckpt_path))
    return checkpoint
Exemplo n.º 11
0
    def sample(self, n_samples, z=None, memory=None, z_conds=None, y=None, fp16=False, temp=1.0, top_k=0, top_p=0.0,
               chunk_size=None, sample_tokens=None, train=False):
        """Sample `n_samples` token sequences from the prior, without gradients.

        `z`, `y` and every entry of `z_conds` must have `n_samples` as their
        first dimension. Sampling is "ancestral" when `z` is None or has a
        zero-length second dimension, and "primed" (continuing from `z`)
        otherwise.

        Returns:
            Tuple `(z, memory)`. `memory` is only replaced on the
            `sample_recurrent` and separate-encoder `primed_sample` paths;
            on the other paths the value passed in is returned unchanged.
        """
        N = n_samples
        # Every provided conditioning tensor must match the batch size.
        if z is not None: assert z.shape[0] == N, f"Expected shape ({N},**), got shape {z.shape}"
        if y is not None: assert y.shape[0] == N, f"Expected shape ({N},**), got shape {y.shape}"
        if z_conds is not None:
            for z_cond in z_conds:
                assert z_cond.shape[0] == N,  f"Expected shape ({N},**), got shape {z_cond.shape}"

        # With no usable prefix in z, sampling starts from scratch.
        no_past_context = (z is None or z.shape[1] == 0)
        if dist.get_rank() == 0:
            name = {True: 'Ancestral', False: 'Primed'}[no_past_context]
            print(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}")

        with t.no_grad():
            # Currently x_cond only uses immediately above layer
            x_cond, y_cond, prime = self.get_cond(z_conds, y)
            if self.single_enc_dec:
                # Single transformer: the prime tokens (and z, if any) are merged
                # into one sequence before sampling.
                # assert chunk_size % self.prime_loss_dims == 0. TODO: Check if needed
                if no_past_context:
                    z, x_cond = self.prior_preprocess([prime], [None, x_cond])
                else:
                    z, x_cond = self.prior_preprocess([prime, z], [None, x_cond])
                # The token budget is extended by self.n_tokens on this path.
                if sample_tokens is not None:
                    sample_tokens += self.n_tokens
                z = self.prior.primed_sample(n_samples, z, x_cond, y_cond, fp16=fp16, temp=temp,
                                             top_k=top_k, top_p=top_p, chunk_size=chunk_size, sample_tokens=sample_tokens)
                z = self.prior_postprocess(z)
            else:
                # Separate encoder: the prime is encoded once and passed as encoder_kv.
                encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
                if no_past_context:
                    if train:
                      # This call returns only z; memory is left as passed in.
                      z = self.prior.sample(n_samples, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp, top_k=top_k,
                                         top_p=top_p, sample_tokens=sample_tokens)
                    else:
                      z, memory = self.prior.sample_recurrent(n_samples, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp, top_k=top_k,
                                          top_p=top_p, sample_tokens=sample_tokens)
                else:
                    z, memory = self.prior.primed_sample(n_samples, z, memory, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp,
                                             top_k=top_k, top_p=top_p, chunk_size=chunk_size, sample_tokens=sample_tokens)
            # The full-shape check only applies when no explicit token budget was given.
            if sample_tokens is None:
                assert_shape(z, (N, *self.z_shape))
        return z, memory
Exemplo n.º 12
0
def calculate_bandwidth(dataset, hps, duration=600):
    """Estimate amplitude and spectral statistics of `dataset`.

    Items are read starting at index `dist.get_rank()` with a stride of
    `max(16, world_size)` until at least `dataset.sr * duration` elements
    have been seen on this rank. The partial sums are then combined with
    `allreduce` when `dist.is_available()`.

    Args:
        dataset: Indexable dataset with an `sr` attribute; each item is an
            array or an `(array, extra)` pair. The array is averaged over
            axis 1 before the STFT (presumably (time, channels) — confirm).
        hps: Hyperparameters, wrapped in `DefaultSTFTValues` to supply
            `n_fft`, `hop_length` and `window_size`.
        duration: Multiplied by `dataset.sr` to get the element budget.

    Returns:
        Dict with `l2` (mean of squares minus squared mean), `l1` (mean
        absolute value) and `spec` (mean norm of the magnitude STFT per item).
    """
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            # Only the audio is needed; the second element is ignored.
            x, _ = x
        samples = x.astype(np.float64)
        # n_fft is passed by keyword: newer librosa releases make every
        # argument after the signal keyword-only, and older ones accept it too.
        stft = librosa.core.stft(np.mean(samples, axis=1),
                                 n_fft=hps.n_fft,
                                 hop_length=hps.hop_length,
                                 win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples**2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean**2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
Exemplo n.º 13
0
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    """Run one pass over `data_processor.train_loader`, updating the model.

    For each batch: forward through `model`, backward via `backward`, skip
    the step on overflow or an over-threshold gradient norm, otherwise step
    the optimizer, scheduler and EMA, then log metrics and periodically
    save checkpoints, draw samples and log inputs/outputs.

    When the scheduler's next learning rate reaches 0.0 the function
    synchronises on a barrier and exits the process.

    Args:
        model: Module used for the forward pass.
        orig_model: Module whose grads are zeroed and which is saved and sampled.
        opt: Optimizer; `opt.step` is called with a `scale` keyword.
        shd: LR scheduler or None.
        scalar: Loss scaler passed to `backward`.
        ema: EMA tracker or None.
        logger: Provides `get_range`, `iters`, `step`, `add_scalar`,
            `set_postfix` and `close_range`.
        metrics: Accumulator with `update(key, val, n)` and `avg(key)`.
        data_processor: Provides `train_loader`.
        hps: Hyperparameters.

    Returns:
        Dict mapping each metric key of the last processed batch to
        `metrics.avg(key)`; empty if the loader yielded nothing.
    """
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    print_all(data_processor.train_loader)
    print_all(len(data_processor.train_loader))

    # Bound up front so the final summary is well-defined for an empty loader.
    _metrics = {}

    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                         scalar=scalar, fp16=hps.fp16, logger=logger)
        # Skip step if overflow
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        # Chained comparison: the grad-norm threshold only applies when it is positive.
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None: shd.step()
        if ema is not None: ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                # Swap EMA weights in for saving, then back out afterwards.
                if ema is not None: ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None: ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in range(1, 1 + hps.iters_before_update) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key:_metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()
    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
Exemplo n.º 14
0
    def __init__(self,
                 z_shapes,
                 l_bins,
                 encoder,
                 decoder,
                 level,
                 downs_t,
                 strides_t,
                 labels,
                 prior_kwargs,
                 x_cond_kwargs,
                 y_cond_kwargs,
                 prime_kwargs,
                 copy_input,
                 labels_v3=False,
                 merged_decoder=False,
                 single_enc_dec=False):
        """Build the prior for one level, plus its conditioning modules.

        Args:
            z_shapes: Per-level shapes; `z_shapes[level]` becomes `self.z_shape`.
            l_bins: Bin count passed to the `Conditioner` and, when
                `copy_input` is set, written into `prime_kwargs['bins']`.
            encoder: Stored as `self.encoder`.
            decoder: Stored as `self.decoder`.
            level: Index of this prior's level; must be < len(z_shapes).
            downs_t: Per-level values passed to `Conditioner` and
                `calculate_strides`.
            strides_t: Per-level values passed to `Conditioner` and
                `calculate_strides`.
            labels: Truthy to enable label (y) conditioning.
            prior_kwargs: Keyword arguments for the main
                `ConditionalAutoregressive2D`. NOTE: 'input_shape' and 'bins'
                are popped from this dict when `single_enc_dec` is set.
            x_cond_kwargs: Extra keyword arguments for `Conditioner`.
            y_cond_kwargs: Extra keyword arguments for `LabelConditioner`.
            prime_kwargs: Keyword arguments for the prime model. NOTE:
                'use_tokens', 'n_tokens' and 'prime_loss_fraction' are popped
                from this dict, so the caller's dict is modified.
            copy_input: When truthy, `prime_kwargs['bins']` is set to `l_bins`.
            labels_v3: Forwarded to `Labeller` as `v3` when `labels` is truthy.
            merged_decoder: Forwarded to the prior in the separate-encoder case.
            single_enc_dec: Selects one joint transformer over prime and
                generated tokens instead of a separate prime encoder.
        """
        super().__init__()

        # These keys are consumed here and must not reach the prime model's constructor.
        self.use_tokens = prime_kwargs.pop('use_tokens')
        self.n_tokens = prime_kwargs.pop('n_tokens')
        self.prime_loss_fraction = prime_kwargs.pop('prime_loss_fraction')

        self.copy_input = copy_input
        if self.copy_input:
            prime_kwargs['bins'] = l_bins

        self.z_shapes = z_shapes
        self.levels = len(self.z_shapes)

        # NOTE(review): this indexing runs before the range assert below, so an
        # out-of-range level raises IndexError rather than the assertion message.
        self.z_shape = self.z_shapes[level]

        self.level = level
        assert level < self.levels, f"Total levels {self.levels}, got level {level}"

        self.l_bins = l_bins

        # Passing functions instead of the vqvae module to avoid getting params
        self.encoder = encoder
        self.decoder = decoder

        # X conditioning
        # Every level except the last index is conditioned on level + 1.
        self.x_cond = (level != (self.levels - 1))
        self.cond_level = level + 1

        # Y conditioning
        self.y_cond = labels

        self.single_enc_dec = single_enc_dec
        # X conditioning
        if self.x_cond:
            self.conditioner_blocks = nn.ModuleList()
            conditioner_block = lambda _level: Conditioner(
                input_shape=z_shapes[_level],
                bins=l_bins,
                down_t=downs_t[_level],
                stride_t=strides_t[_level],
                **x_cond_kwargs)
            if dist.get_rank() == 0: print(f"Conditioning on 1 above level(s)")
            self.conditioner_blocks.append(conditioner_block(self.cond_level))

        # Y conditioning
        if self.y_cond:
            self.n_time = self.z_shape[
                0]  # Assuming STFT=TF order and raw=T1 order, so T is first dim
            # The time signal is only included when there is no x conditioning.
            self.y_emb = LabelConditioner(n_time=self.n_time,
                                          include_time_signal=not self.x_cond,
                                          **y_cond_kwargs)

        # Lyric conditioning
        if single_enc_dec:
            # Single encoder-decoder transformer
            # Index 0 describes the prime tokens, index 1 the generated tokens.
            self.prior_shapes = [(self.n_tokens, ),
                                 prior_kwargs.pop('input_shape')]
            self.prior_bins = [prime_kwargs['bins'], prior_kwargs.pop('bins')]
            self.prior_dims = [np.prod(shape) for shape in self.prior_shapes]
            # Offset of each vocabulary inside the combined bin range.
            self.prior_bins_shift = np.cumsum([0, *self.prior_bins])[:-1]
            self.prior_width = prior_kwargs['width']
            print_once(
                f'Creating cond. autoregress with prior bins {self.prior_bins}, '
            )
            print_once(f'dims {self.prior_dims}, ')
            print_once(f'shift {self.prior_bins_shift}')
            print_once(f'input shape {sum(self.prior_dims)}')
            print_once(f'input bins {sum(self.prior_bins)}')
            print_once(f'Self copy is {self.copy_input}')

            self.prime_loss_dims, self.gen_loss_dims = self.prior_dims[
                0], self.prior_dims[1]
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            # One model over the concatenated sequence and the summed bin count.
            self.prior = ConditionalAutoregressive2D(
                input_shape=(sum(self.prior_dims), ),
                bins=sum(self.prior_bins),
                x_cond=(self.x_cond or self.y_cond),
                y_cond=True,
                prime_len=self.prime_loss_dims,
                **prior_kwargs)

        else:
            # Separate encoder-decoder transformer
            if self.n_tokens != 0 and self.use_tokens:
                from jukebox.transformer.ops import Conv1D
                prime_input_shape = (self.n_tokens, )
                self.prime_loss_dims = np.prod(prime_input_shape)
                self.prime_acts_width, self.prime_state_width = prime_kwargs[
                    'width'], prior_kwargs['width']
                # Encode-only model over the prime tokens.
                self.prime_prior = ConditionalAutoregressive2D(
                    input_shape=prime_input_shape,
                    x_cond=False,
                    y_cond=False,
                    only_encode=True,
                    **prime_kwargs)
                # Projects from the prime model's width to the main prior's width.
                self.prime_state_proj = Conv1D(
                    self.prime_acts_width,
                    self.prime_state_width,
                    init_scale=prime_kwargs['init_scale'])
                self.prime_state_ln = LayerNorm(self.prime_state_width)
                self.prime_bins = prime_kwargs['bins']
                self.prime_x_out = nn.Linear(self.prime_state_width,
                                             self.prime_bins,
                                             bias=False)
                nn.init.normal_(self.prime_x_out.weight,
                                std=0.02 * prior_kwargs['init_scale'])
            else:
                # No prime tokens are used on this path.
                self.prime_loss_dims = 0
            self.gen_loss_dims = np.prod(self.z_shape)
            self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
            self.prior = ConditionalAutoregressive2D(
                x_cond=(self.x_cond or self.y_cond),
                y_cond=self.y_cond,
                encoder_dims=self.prime_loss_dims,
                merged_decoder=merged_decoder,
                **prior_kwargs)

        self.n_ctx = self.gen_loss_dims
        self.downsamples = calculate_strides(strides_t, downs_t)
        self.cond_downsample = self.downsamples[
            level + 1] if level != self.levels - 1 else None
        # Product of the downsample factors for levels 0..level.
        self.raw_to_tokens = np.prod(self.downsamples[:level + 1])
        self.sample_length = self.n_ctx * self.raw_to_tokens
        if labels:
            self.labels_v3 = labels_v3
            self.labeller = Labeller(self.y_emb.max_bow_genre_size,
                                     self.n_tokens,
                                     self.sample_length,
                                     v3=self.labels_v3)

        # Printed on every rank (plain print, not print_once).
        print(
            f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}"
        )
Exemplo n.º 15
0
def print_once(msg):
    """Print `msg` on rank 0 only, or always when `dist` is unavailable."""
    if dist.is_available() and dist.get_rank() != 0:
        return
    print(msg)
Exemplo n.º 16
0
def print_all(msg):
    """Print `msg` from one rank out of every 8, prefixed with `rank // 8`.

    When `dist` is unavailable the message is printed as-is.
    """
    if not dist.is_available():
        print(msg)
        return
    rank = dist.get_rank()
    if rank % 8 == 0:
        print(f'{rank//8}: {msg}')
Exemplo n.º 17
0
def get_range(x):
    """Return `x` wrapped by `def_tqdm` on rank 0, and `x` itself elsewhere."""
    return def_tqdm(x) if dist.get_rank() == 0 else x