Beispiel #1
0
def freeze_rng_state():
    """Snapshot the global RNG state(s), yield, then restore them.

    Generator meant to be wrapped as a context manager (e.g. with
    ``contextlib.contextmanager``). The CPU generator state is always saved.
    The current CUDA device's generator state is saved only when CUDA is
    available. Restoration happens in a ``finally`` clause, so an exception
    raised inside the managed block cannot leak a modified RNG state.
    """
    rng_state = torch.get_rng_state()
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        cuda_rng_state = torch.cuda.get_rng_state()
    try:
        yield
    finally:
        # Same restore order as before: CUDA first, then CPU.
        if cuda_available:
            torch.cuda.set_rng_state(cuda_rng_state)
        torch.set_rng_state(rng_state)
Beispiel #2
0
	def forward(self,seed):
		"""Multiply every parameter by a seeded Bernoulli mask.

		The masks for ``self.params`` are drawn after seeding the global RNG
		with ``seed``, so the same seed always yields the same masks. The
		global RNG state is saved beforehand and restored afterwards, even if
		mask generation raises, so the caller's random streams are unaffected.
		CUDA RNG handling is skipped when CUDA is not available.

		Args:
			seed: integer seed for the mask draw.

		Returns:
			list: ``mask_i * param_i`` for each entry of ``self.params``, in order.
		"""
		use_cuda = torch.cuda.is_available()
		# Store rng
		rng_cpu = torch.get_rng_state()
		rng_gpu = torch.cuda.get_rng_state() if use_cuda else None
		try:
			torch.manual_seed(seed)
			if use_cuda:
				torch.cuda.manual_seed(seed)
			# bernoulli_(p) fills a copy of each parameter with 1 (prob. p) or 0.
			mask = [Variable(param.data.clone().bernoulli_(self.p)) for param in self.params]
		finally:
			# Recover rng
			torch.set_rng_state(rng_cpu)
			if use_cuda:
				torch.cuda.set_rng_state(rng_gpu)
		# Compute output
		return [m * param for m, param in zip(mask, self.params)]
Beispiel #3
0
	def noise(self,seed=None):
		"""Draw standard-normal noise shaped like each tensor in ``self.mean``.

		Args:
			seed: optional integer seed. When ``None``, the noise comes from the
				current global RNG stream. Otherwise the global RNG is seeded with
				``seed`` for the draw and its previous state is restored
				afterwards, even if the draw raises. The same seed therefore
				gives the same noise without disturbing other random streams.
				CUDA RNG handling is skipped when CUDA is not available.

		Returns:
			list: one ``Variable`` of N(0, 1) samples per entry of ``self.mean``.
		"""
		def draw():
			# normal_ overwrites a copy of each tensor, keeping its shape/device.
			return [Variable(param.data.clone().normal_(0,1)) for param in self.mean]

		if seed is None:
			return draw()

		use_cuda = torch.cuda.is_available()
		rng_cpu = torch.get_rng_state()
		rng_gpu = torch.cuda.get_rng_state() if use_cuda else None
		try:
			torch.manual_seed(seed)
			if use_cuda:
				torch.cuda.manual_seed(seed)
			return draw()
		finally:
			# Recover rng
			torch.set_rng_state(rng_cpu)
			if use_cuda:
				torch.cuda.set_rng_state(rng_gpu)
Beispiel #4
0
def checkpoint(model, best_loss, epoch, LR, path='results/checkpoint'):
    """
    Saves checkpoint of torchvision model during training.

    The whole ``model`` object (not only its ``state_dict``) is saved together
    with the CPU RNG state, best loss, epoch and learning rate.

    Args:
        model: torchvision model to be saved
        best_loss: best val loss achieved so far in training
        epoch: current epoch of training
        LR: current learning rate in training
        path: destination file, ``'results/checkpoint'`` by default. Its
            parent directory is created if it does not exist yet.
    Returns:
        None
    """
    import os

    print('saving')
    state = {
        'model': model,
        'best_loss': best_loss,
        'epoch': epoch,
        'rng_state': torch.get_rng_state(),
        'LR': LR
    }

    directory = os.path.dirname(path)
    if directory:
        # torch.save does not create missing directories.
        os.makedirs(directory, exist_ok=True)
    torch.save(state, path)
Beispiel #5
0
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        # Create a random image that is consistent with the index id. The
        # draws use a private CPU generator seeded with the index, so the
        # global RNG is never touched. torch.manual_seed would also reseed the
        # CUDA generators, and the previous save/restore only put the CPU
        # state back. A freshly seeded CPU generator produces the same
        # sequence as the default CPU generator given the same seed.
        generator = torch.Generator()
        generator.manual_seed(index + self.random_offset)
        img = torch.randn(*self.image_size, generator=generator)
        # NOTE(review): target is a 0-dim float tensor, not an int — confirm
        # that callers / target_transform expect that.
        target = torch.Tensor(1).random_(0, self.num_classes, generator=generator)[0]

        # convert to PIL Image
        img = transforms.ToPILImage()(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
Beispiel #6
0
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
    """
    Fork the RNG: when the managed block exits, the RNG is reset to the
    state it had on entry.

    Arguments:
        devices (iterable of CUDA IDs): CUDA devices whose RNG state is forked.
            The CPU RNG state is always forked. When omitted, every CUDA
            device reported by ``torch.cuda.device_count()`` is forked. A
            one-time warning is emitted if more than one device is reported,
            because snapshotting many devices is slow. Passing ``devices``
            explicitly suppresses that warning.
        enabled (bool): when ``False`` nothing is forked and the generator
            simply yields. This makes it easy to switch the context manager
            off without re-indenting the code that uses it.

    Internal arguments:
        _caller: name of the public function the user called; used only in
            the warning text.
        _devices_kw: name of that function's devices keyword; used only in
            the warning text.
    """
    import torch.cuda
    global _fork_rng_warned_already

    if not enabled:
        yield
        return

    if devices is not None:
        # The device list is walked twice (snapshot and restore), so a
        # one-shot iterator such as a generator is materialised first.
        devices = list(devices)
    else:
        num_devices = torch.cuda.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            template = (
                "CUDA reports that you have {num_devices} available devices, and you "
                "have used {caller} without explicitly specifying which devices are being used. "
                "For safety, we initialize *every* CUDA device by default, which "
                "can be quite slow if you have a lot of GPUs.  If you know that you are only "
                "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                "you are actually using.  For example, if you are using CPU only, "
                "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0].  To initialize "
                "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                "to `range(torch.cuda.device_count())`."
            )
            message = template.format(num_devices=num_devices, caller=_caller,
                                      devices_kw=_devices_kw)
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))

    saved_cpu = torch.get_rng_state()
    saved_cuda = []
    for dev in devices:
        with torch.cuda.device(dev):
            saved_cuda.append(torch.cuda.get_rng_state())

    try:
        yield
    finally:
        # Always put every forked generator back, even if the body raised.
        torch.set_rng_state(saved_cpu)
        for dev, state in zip(devices, saved_cuda):
            with torch.cuda.device(dev):
                torch.cuda.set_rng_state(state)
    def backward(ctx, *grads):
        """Recompute the checkpointed forward pass and backpropagate through it.

        Rebuilds the saved (possibly partitioned) inputs and replays
        ``ctx.run_function`` under the CPU / CUDA / tracker RNG states captured
        in forward, so stochastic ops reproduce the same values. It then runs
        ``torch.autograd.backward`` on the recomputed outputs that require grad.

        Args:
            ctx: autograd context populated by the matching ``forward``.
            *grads: incoming gradients, paired positionally with the
                recomputed tensor outputs.

        Returns:
            tuple: two leading ``None`` entries, then one entry per merged input:
            the input's ``.grad`` for tensors, ``None`` for non-tensor args.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is False,
                i.e. checkpointing is used through ``.grad()``.
        """
        global timers
        see_memory_usage("In backward", force=False)
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            # Remove this module's pointers to the contiguous buffer memory so
            # it can be garbage collected once the checkpoints have been used.
            # Tensors kept by save_for_backward stay alive. (A former
            # `for buffers in ...: buffers = []` loop only rebound a loop
            # variable and had no effect.)
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS

        if PARTITION_ACTIVATIONS:
            # bfloat16 partitions are gathered with fp32 communication.
            if ctx.saved_tensors and ctx.saved_tensors[0].dtype == torch.bfloat16:
                FP32_COMM = True
            else:
                FP32_COMM = False
            inputs = get_full_inputs(ctx.saved_tensors,
                                     device=cuda_device if PA_TO_CPU else None,
                                     fp32_comm=FP32_COMM)
        else:
            inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)

        # Add non tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        try:
            # Set the states to what they were before the forward pass.
            torch.set_rng_state(ctx.fwd_cpu_rng_state)
            _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

            see_memory_usage("In backward checkpointing code before forward", force=False)

            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

            see_memory_usage("In backward checkpointing code after forward", force=False)
        finally:
            # Set the states back to what they were at the start of this
            # function, even if the recomputation raised.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward", force=False)

        torch.autograd.backward(output_tensors, grad_tensors)

        see_memory_usage("After backward checkpointing code after backward", force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        # NOTE(review): the two leading Nones presumably cover the non-tensor
        # leading forward() arguments (e.g. run_function); confirm against the
        # forward signature in use. ctx itself receives no gradient.
        ret_list = [None, None]
        ret_list.extend(inp.grad if torch.is_tensor(inp) else None
                        for inp in detached_inputs)

        return tuple(ret_list)
Beispiel #8
0
def train_CcGAN(kernel_sigma,
                kappa,
                train_images,
                train_labels,
                netG,
                netD,
                save_images_folder,
                save_models_folder=None,
                clip_label=False):
    '''
    Train a continuous conditional GAN (CcGAN) with a vicinal discriminator loss.

    Note that train_images are not normalized to [-1,1]; each real batch is
    normalized through ``IMGs_dataset(..., normalize=True)`` inside the loop.

    Args:
        kernel_sigma: std of the Gaussian noise added to sampled target labels.
        kappa: vicinity parameter. For ``threshold_type == "hard"`` it is the
            half-width of the label window. Otherwise it scales the soft
            weight ``exp(-kappa * (y - y_target)**2)``.
        train_images: training images, fancy-indexed by integer arrays. Values
            are asserted to exceed 1, i.e. not yet normalized.
        train_labels: numpy array of training labels. Fake-label windows are
            clamped to [0, 1], so labels are presumably scaled to [0, 1] —
            confirm.
        netG, netD: generator and discriminator modules.
        save_images_folder: folder receiving a sample grid every 100 iterations.
        save_models_folder: if not None, a checkpoint is loaded from it when
            ``resume_niters > 0``. Checkpoints are written under it every
            ``save_niters_freq`` iterations and at the last iteration.
        clip_label: clip the noisy target labels to [0, 1].

    Returns:
        tuple: (netG, netD) after training.

    Relies on module-level settings: device, lr_g, lr_d, resume_niters,
    niters, threshold_type, nonzero_soft_weight_threshold, batch_size_max,
    batch_size_disc, batch_size_gene, dim_gan, save_niters_freq.
    '''

    def vicinity_indices(target_label):
        # Indices of training samples whose label lies in the vicinity of
        # target_label. Hard vicinity: |y - target| <= kappa. Soft vicinity
        # (SVDL): the weight exp(-kappa * (y - target)**2) is at least
        # nonzero_soft_weight_threshold, i.e. the weight function reversed.
        if threshold_type == "hard":
            return np.where(np.abs(train_labels - target_label) <= kappa)[0]
        return np.where((train_labels - target_label)**2 <=
                        -np.log(nonzero_soft_weight_threshold) / kappa)[0]

    def fake_label_bounds(target_label):
        # Window [lb, ub] around target_label matching the vicinity above,
        # clamped to [0, 1]; fake labels are drawn uniformly from it.
        if threshold_type == "hard":
            half_width = kappa
        else:
            half_width = np.sqrt(-np.log(nonzero_soft_weight_threshold) / kappa)
        return max(0.0, target_label - half_width), min(target_label + half_width, 1.0)

    netG = netG.to(device)
    netD = netD.to(device)

    optimizerG = torch.optim.Adam(netG.parameters(),
                                  lr=lr_g,
                                  betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(),
                                  lr=lr_d,
                                  betas=(0.5, 0.999))

    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/CcGAN_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(
            threshold_type, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    #end if

    #################
    unique_train_labels = np.sort(np.array(list(set(train_labels))))

    # printed images with labels between the 5-th quantile and 95-th quantile of training labels
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_gan, dtype=torch.float).to(device)
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    # Row i of the sample grid uses selected_labels[i] for all n_col columns.
    y_fixed = np.repeat(selected_labels, n_col)
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,
                                                               1).to(device)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):
        '''  Train Discriminator   '''
        ## randomly draw batch_size_disc y's from unique_train_labels
        batch_target_labels_in_dataset = np.random.choice(unique_train_labels,
                                                          size=batch_size_max,
                                                          replace=True)
        ## add Gaussian noise; we estimate image distribution conditional on these labels
        batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_max)
        batch_target_labels_with_epsilon = batch_target_labels_in_dataset + batch_epsilons
        if clip_label:
            batch_target_labels_with_epsilon = np.clip(
                batch_target_labels_with_epsilon, 0.0, 1.0)

        batch_target_labels = batch_target_labels_with_epsilon[
            0:batch_size_disc]

        ## find index of real images with labels in the vicinity of batch_target_labels
        ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
        batch_real_indx = np.zeros(
            batch_size_disc, dtype=int
        )  #index of images in the dataset; the labels of these images are in the vicinity
        batch_fake_labels = np.zeros(batch_size_disc)

        for j in range(batch_size_disc):
            ## index for real images
            indx_real_in_vicinity = vicinity_indices(batch_target_labels[j])

            ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
            while len(indx_real_in_vicinity) < 1:
                # Redraw the noise for this label only. [0] keeps the value a
                # scalar instead of assigning a 1-element array to one slot
                # (newer NumPy warns on that array-to-scalar conversion).
                batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)[0]
                batch_target_labels[
                    j] = batch_target_labels_in_dataset[j] + batch_epsilons_j
                if clip_label:
                    batch_target_labels = np.clip(batch_target_labels, 0.0,
                                                  1.0)
                indx_real_in_vicinity = vicinity_indices(batch_target_labels[j])
            #end while len(indx_real_in_vicinity)<1

            assert len(indx_real_in_vicinity) >= 1

            batch_real_indx[j] = np.random.choice(indx_real_in_vicinity,
                                                  size=1)[0]

            ## labels for fake images generation
            lb, ub = fake_label_bounds(batch_target_labels[j])
            assert lb <= ub
            assert lb >= 0 and ub >= 0
            assert lb <= 1 and ub <= 1
            batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
        #end for j

        ## draw the real image batch from the training set
        batch_real_images = train_images[batch_real_indx]
        assert batch_real_images.max() > 1
        batch_real_labels = train_labels[batch_real_indx]
        batch_real_labels = torch.from_numpy(batch_real_labels).type(
            torch.float).to(device)

        ## normalize real images to [-1,1]
        trainset = IMGs_dataset(batch_real_images, labels=None, normalize=True)
        train_dataloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size_disc, shuffle=False)
        # next(iter(...)) instead of the `.next()` method, which is not
        # available on current PyTorch DataLoader iterators.
        batch_real_images = next(iter(train_dataloader))
        assert len(batch_real_images) == batch_size_disc
        batch_real_images = batch_real_images.type(torch.float).to(device)
        assert batch_real_images.max().item() <= 1

        ## generate the fake image batch
        batch_fake_labels = torch.from_numpy(batch_fake_labels).type(
            torch.float).to(device)
        z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).to(device)
        batch_fake_images = netG(z, batch_fake_labels)

        ## target labels on gpu
        batch_target_labels = torch.from_numpy(batch_target_labels).type(
            torch.float).to(device)

        ## weight vector
        if threshold_type == "soft":
            real_weights = torch.exp(
                -kappa *
                (batch_real_labels - batch_target_labels)**2).to(device)
            fake_weights = torch.exp(
                -kappa *
                (batch_fake_labels - batch_target_labels)**2).to(device)
        else:
            real_weights = torch.ones(batch_size_disc,
                                      dtype=torch.float).to(device)
            fake_weights = torch.ones(batch_size_disc,
                                      dtype=torch.float).to(device)
        #end if threshold type

        # forward pass
        real_dis_out = netD(batch_real_images, batch_target_labels)
        fake_dis_out = netD(batch_fake_images.detach(), batch_target_labels)

        d_loss = -torch.mean(
            real_weights.view(-1) *
            torch.log(real_dis_out.view(-1) + 1e-20)) - torch.mean(
                fake_weights.view(-1) *
                torch.log(1 - fake_dis_out.view(-1) + 1e-20))

        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()
        '''  Train Generator   '''
        netG.train()

        # generate fake images
        batch_target_labels = batch_target_labels_with_epsilon[
            0:batch_size_gene]
        batch_target_labels = torch.from_numpy(batch_target_labels).type(
            torch.float).to(device)

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).to(device)
        batch_fake_images = netG(z, batch_target_labels)

        # loss
        dis_out = netD(batch_fake_images, batch_target_labels)
        g_loss = -torch.mean(torch.log(dis_out + 1e-20))

        # backward
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss
        if (niter + 1) % 20 == 0:
            print(
                "CcGAN: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]"
                % (niter + 1, niters, d_loss.item(), g_loss.item(),
                   real_dis_out.mean().item(), fake_dis_out.mean().item(),
                   timeit.default_timer() - start_time))

        if (niter + 1) % 100 == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data,
                           save_images_folder + '/{}.png'.format(niter + 1),
                           nrow=n_row,
                           normalize=True)

        if save_models_folder is not None and (
            (niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/CcGAN_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(
                threshold_type, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    #end for niter
    return netG, netD
Beispiel #9
0
def get_rng_state():
    """Return a dict snapshot of the global RNG state(s).

    Always contains ``"torch_rng_state"`` (the CPU generator). Adds
    ``"cuda_rng_state"`` (the current CUDA device) when CUDA is available.
    """
    snapshot = dict(torch_rng_state=torch.get_rng_state())
    if not torch.cuda.is_available():
        return snapshot
    snapshot["cuda_rng_state"] = torch.cuda.get_rng_state()
    return snapshot
Beispiel #10
0
            # The loss is converted to a NumPy array here, so it is accumulated
            # outside autograd; .item() on the NumPy array still yields a scalar.
            loss = criterion(outputs[:, 0],
                             labels).cpu().detach().numpy()  # calculate loss
            print(f'Validation loss = {loss.item()}')

            valLoss += loss

        # Mean validation loss over the batches of val_dataloader for this epoch.
        allValLoss[epoch] = valLoss / len(val_dataloader)

        # Rewrite the full loss histories every epoch.
        np.savetxt(out_path + '/allValLoss.txt', allValLoss)
        np.savetxt(out_path + '/allTrainLoss.txt', allTrainLoss)

        # Keep the weights of the best (lowest validation loss) epoch so far.
        if allValLoss[epoch] < bestValLoss:
            print(
                f'Best seen validation performance ({bestValLoss} -> {allValLoss[epoch]}), saving...'
            )
            torch.save(model_2.state_dict(), out_path + bestValLossNetFileName)
            np.savetxt(out_path + '/bestEpochNum.txt', np.array([epoch]))
            bestValLoss = allValLoss[epoch]

    # checkpointing at the end of every epoch
    torch.save(model_2.state_dict(), out_path + currModelFilename)
    # NOTE(review): unlike the '/'-joined paths above, the next two f-strings
    # put no separator between out_path and the file name. Unless out_path
    # ends with '/', these files are written next to (not inside) the output
    # directory. Check whatever code reads them back before changing the path.
    np.savetxt(f'{out_path}lastCompletedEpoch.txt', np.asarray([epoch]))
    # The CPU RNG state (a uint8 tensor) is written as text in np.savetxt's
    # default float format; a reader must presumably cast it back to uint8
    # before torch.set_rng_state — TODO confirm.
    np.savetxt(f'{out_path}randomState.txt', torch.get_rng_state().numpy())

    # Switch back to training mode for the next epoch's training pass.
    model_2 = model_2.train()

    print(f'Epoch = {epoch} finished')
print('Finished Training')
end = time.time()
print(f'Training took {end-start} seconds')
Beispiel #11
0
 def freeze(self):
     """Return torch's CPU RNG state, or None when torch has not been imported.

     Only ``sys.modules`` is consulted, so calling this never triggers the
     first import of torch.
     """
     if 'torch' not in sys.modules:
         return None
     import torch
     return torch.get_rng_state()
Beispiel #12
0
def get_rng_state():
    """Capture the global RNG state(s) in a dict.

    Keys: ``'rng_state'`` for the CPU generator, plus ``'cuda_rng_state'``
    for the current CUDA device when CUDA is available.
    """
    snapshot = {'rng_state': torch.get_rng_state()}
    cuda_ok = torch.cuda.is_available()
    return {**snapshot, 'cuda_rng_state': torch.cuda.get_rng_state()} if cuda_ok else snapshot
Beispiel #13
0
    def forward(ctx, run_function, *args):
        """Run ``run_function`` without building a graph and save what backward needs.

        With PARTITION_ACTIVATIONS, every argument except the last is replaced
        by this rank's slice of it. The slice is flattened, narrowed with
        get_partition_start / get_partition_size, and cloned. It is then
        optionally copied into a preallocated contiguous buffer
        (CONTIGUOUS_CHECKPOINTING) and/or moved to CPU (PA_TO_CPU). Each
        argument's original size is saved next to it, presumably so backward
        can reassemble the full tensor. The CPU, CUDA and CUDA-RNG-tracker
        states are recorded on ctx so backward can replay the forward
        identically.

        Args:
            ctx: autograd context.
            run_function: callable that is run here and recomputed in backward.
            *args: inputs to run_function. Every one is moved with
                ``.to(cuda_device)``, so all are presumably tensors — confirm.

        Returns:
            Whatever run_function returns, computed under ``torch.no_grad()``.
        """
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Timers are created lazily the first time profiling is requested.
        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        # Resolve the model-parallel topology once; without an mpu, act as a
        # single rank.
        if mp_rank is None:
            if mpu is not None:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        # First call only: log the configuration (rank 0) and pick the
        # current CUDA device plus a side stream on it.
        if cuda_device is None:
            see_memory_usage("First Forward Begining", force=True)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
                )
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            #inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
            #inputs.append(args[-1])

            # Build this rank's partition for every argument except the last,
            # which is appended unpartitioned below.
            inputs = []
            for i, item in enumerate(args[:-1]):
                partition_size = get_partition_size(item)
                partition = item.detach().contiguous().view(-1).narrow(
                    0, get_partition_start(item), partition_size).clone()

                if CONTIGUOUS_CHECKPOINTING:
                    buffer_device = torch.device(
                        'cpu') if PA_TO_CPU else partition.device

                    # Lazily allocate num_layers buffers for argument slot i
                    # (a new slot, or one cleared to None), then copy this
                    # partition into the next unused buffer. The
                    # comprehensions' `i` is local to them and does not
                    # clobber the loop's `i`.
                    if i >= len(contiguous_data_buffers):
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers.append(tensor_list)
                        data_offsets.append(0)
                    elif contiguous_data_buffers[i] is None:
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers[i] = tensor_list
                        data_offsets[i] = 0

                    contiguous_partition = contiguous_data_buffers[i][
                        data_offsets[i]].data.copy_(partition.data)
                    data_offsets[i] = data_offsets[i] + 1
                    inputs.append(contiguous_partition)
                else:
                    partition = partition.cpu() if PA_TO_CPU else partition
                    inputs.append(partition)

            inputs.append(args[-1])

        #just in case something funky is happening such as reuse of inputs
        inputs_cuda = [item.to(cuda_device) for item in args]

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        #ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        del inputs_cuda

        #with torch.cuda.stream(transport_stream):
        #if PARTITION_ACTIVATIONS:
        #    new_args = []
        #    for arg, inp in zip(args,inputs):
        #        size= torch.tensor(arg.size())
        #        arg.data = inp.data
        #        new_args.append(arg)
        #        new_args.append(size)
        #    ctx.save_for_backward(*new_args)

        if PARTITION_ACTIVATIONS:
            # Save (arg, original size) pairs. Each arg's storage is swapped
            # in place for its partitioned copy, so the full-size activation
            # is no longer held through it.
            new_args = []
            for i, (arg, inp) in enumerate(zip(args, inputs)):
                size = torch.tensor(arg.size())

                arg.data = inp.data
                new_args.append(arg)

                if CONTIGUOUS_CHECKPOINTING:
                    # Same lazy per-slot allocation as above, but one flat
                    # buffer of numel * num_layers entries advanced by numel.
                    numel = size.numel()
                    if i >= len(contiguous_size_buffers):
                        tmp = torch.tensor(())
                        contiguous_size_buffers.append(
                            tmp.new_empty([numel * num_layers],
                                          dtype=size.dtype,
                                          device=size.device))
                        size_offsets.append(0)
                    elif contiguous_size_buffers[i] is None:
                        tmp = torch.tensor(())
                        contiguous_size_buffers[i] = tmp.new_empty(
                            [numel * num_layers],
                            dtype=size.dtype,
                            device=size.device)
                        size_offsets[i] = 0

                    contiguous_size = contiguous_size_buffers[i].narrow(
                        0, size_offsets[i], numel).data.copy_(size.data)
                    contiguous_size = contiguous_size.view_as(size)
                    size_offsets[i] = size_offsets[i] + numel
                    new_args.append(contiguous_size)
                else:
                    new_args.append(size)
                #if dist.get_rank() == 0:
                #    logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")

            ctx.save_for_backward(*new_args)
        else:
            # Without partitioning the original arguments are saved as-is.
            ctx.save_for_backward(*args)
        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        return outputs
Beispiel #14
0
def train(total_iters=0, skipped_iters=0, elapsed_time=0):
    """Run one pass over ``train_data``, stepping the optimizer every batch.

    Relies on module-level objects: model, optim, LR, criterion, args,
    train_data, get_batch, init_hidden, DDP and epoch.

    Args:
        total_iters: iterations completed before this call. It drives the
            periodic checkpointing controlled by ``args.save_iters``.
        skipped_iters: running count of optimizer steps whose LR step was
            skipped because the fp16 optimizer reported an overflow.
        elapsed_time: seconds spent in earlier calls, added to the reported
            total time. Formerly defaulted to ``False``, which is 0 in
            arithmetic.

    Returns:
        tuple: (average loss over the final logging window, updated skipped_iters).
    """
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    t0 = start_time
    ntokens = args.data_size
    hidden = init_hidden(args.batch_size)
    distributed = isinstance(model, DDP)
    max_iters = len(train_data)

    def log(epoch, i, lr, ms_iter, total_time, loss, scale):
        # Implicitly concatenated literals; the former backslash line
        # continuation embedded the next line's indentation in the message.
        print(('| epoch {:3d} | {:5d}/{:5d} batches | lr {:.2E} | ms/batch {:.3E} '
               '| total time {:.3E} | loss {:.2E} | ppl {:8.2f} | loss scale {:8.2f}').format(
                   epoch, i, max_iters, lr, ms_iter, total_time, loss,
                   math.exp(min(loss, 20)), scale))

    for i, batch in enumerate(train_data):
        data, targets, reset_mask = get_batch(batch)
        optim.zero_grad()
        output, hidden = model(data, reset_mask=reset_mask)
        loss = criterion(
            output.view(-1, ntokens).contiguous().float(),
            targets.view(-1).contiguous())
        # Accumulates a 0-dim tensor; converted with float() when logged.
        total_loss += loss.data.float()

        if args.fp16:
            optim.backward(loss, update_master_grads=False)
        else:
            loss.backward()

        if distributed:
            torch.distributed.all_reduce(loss.data)
            loss.data /= args.world_size
            model.allreduce_params()

        # clipping gradients helps prevent the exploding gradient problem in RNNs / LSTMs.
        if args.clip > 0:
            if not args.fp16:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            else:
                optim.clip_master_grads(clip=args.clip)

        if args.fp16:
            optim.update_master_grads()

        optim.step()

        # step learning rate and log training progress
        lr = optim.param_groups[0]['lr']
        if not args.fp16:
            LR.step()
        else:
            # if fp16 optimizer skips gradient step due to explosion do not step lr
            if not optim.overflow:
                LR.step()
            else:
                skipped_iters += 1

        # log current results (the last batch is logged after the loop)
        if ((i + 1) % args.log_interval == 0) and (i != max_iters - 1):
            cur_loss = float(total_loss) / args.log_interval
            cur_time = time.time()
            elapsed = cur_time - start_time
            total_elapsed = cur_time - t0 + elapsed_time
            log(epoch, i + 1, lr, elapsed * 1000 / args.log_interval,
                total_elapsed, cur_loss,
                args.loss_scale if not args.fp16 else optim.loss_scale)
            total_loss = 0
            start_time = cur_time
            sys.stdout.flush()

        # save current model progress. If distributed only save from worker 0
        if args.save_iters and total_iters % args.save_iters == 0 and total_iters > 0 and args.rank < 1:
            save_dir = os.path.splitext(args.save)[0]
            with open(os.path.join(save_dir, 'e%s.pt' % (str(total_iters), )), 'wb') as f:
                torch.save(model.state_dict(), f)
            if args.save_optim:
                with open(os.path.join(save_dir, 'optim.pt'), 'wb') as f:
                    optim_sd = optim.state_dict()
                    optim_sd['iter'] = total_iters
                    optim_sd['skipped_iter'] = skipped_iters
                    torch.save(optim_sd, f)
                    del optim_sd

                with open(os.path.join(save_dir, 'rng.pt'), 'wb') as f:
                    torch.save((torch.cuda.get_rng_state(),
                                torch.get_rng_state()), f)
            if args.cuda:
                torch.cuda.synchronize()
        total_iters += 1
    #final logging
    elapsed_iters = max_iters % args.log_interval
    if elapsed_iters == 0:
        elapsed_iters = args.log_interval
    cur_loss = float(total_loss) / elapsed_iters
    cur_time = time.time()
    elapsed = cur_time - start_time
    total_elapsed = cur_time - t0 + elapsed_time
    log(epoch, max_iters, lr, elapsed * 1000 / elapsed_iters, total_elapsed,
        cur_loss, args.loss_scale if not args.fp16 else optim.loss_scale)

    return cur_loss, skipped_iters
Beispiel #15
0
def train_AE():
    """Train the sparse autoencoder formed by ``net_encoder`` and ``net_decoder``.

    Relies on module-level globals defined outside this function
    (``net_encoder``, ``net_decoder``, ``trainloader``, ``base_lr``, ``epochs``,
    ``resume_epoch``, ``lambda_sparsity``, ``lr_decay_epochs``,
    ``lr_decay_factor``, ``save_models_folder``,
    ``save_AE_images_in_train_folder``, ``adjust_learning_rate``).

    The loss is ``MSE(reconstruction, input) + lambda_sparsity * mean(features)``,
    optimized with Adam. Every 100 iterations a grid of reconstructions is
    saved, every 20 iterations the running loss is printed, and every 50 epochs
    a resumable checkpoint (networks, optimizer, iteration count, RNG states)
    is written. When ``resume_epoch > 0`` training resumes from the matching
    checkpoint.

    Returns:
        tuple: ``(net_encoder, net_decoder)`` after training.
    """

    # define optimizer
    params = list(net_encoder.parameters()) + list(net_decoder.parameters())
    optimizer = torch.optim.Adam(params,
                                 lr=base_lr,
                                 betas=(0.5, 0.999),
                                 weight_decay=1e-4)

    # criterion
    criterion = nn.MSELoss()

    if resume_epoch > 0:
        print("Loading ckpt to resume training AE >>>")
        ckpt_fullpath = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(
            resume_epoch, lambda_sparsity)
        checkpoint = torch.load(ckpt_fullpath)
        net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
        net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
        # Older checkpoints only carry the CPU RNG state; restore the CUDA
        # generator too when it was saved.
        if 'cuda_rng_state' in checkpoint and torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
        gen_iterations = checkpoint['gen_iterations']
    else:
        gen_iterations = 0

    start_time = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):

        adjust_learning_rate(epoch, epochs, optimizer, base_lr,
                             lr_decay_epochs, lr_decay_factor)

        train_loss = 0

        for batch_idx, batch_real_images in enumerate(trainloader):

            net_encoder.train()
            net_decoder.train()

            batch_size_curr = batch_real_images.shape[0]

            batch_real_images = batch_real_images.type(torch.float).cuda()

            batch_features = net_encoder(batch_real_images)
            batch_recons_images = net_decoder(batch_features)
            # Reconstruction error plus a sparsity penalty on the code, based on
            # https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
            loss = criterion(
                batch_recons_images,
                batch_real_images) + lambda_sparsity * batch_features.mean()

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            gen_iterations += 1

            if gen_iterations % 100 == 0:
                n_row = min(10, int(np.sqrt(batch_size_curr)))
                # Visualize in eval mode: in train mode even a no_grad forward
                # pass updates BatchNorm running statistics (if the networks
                # contain any) with these extra images.
                net_encoder.eval()
                net_decoder.eval()
                with torch.no_grad():
                    batch_recons_images = net_decoder(
                        net_encoder(batch_real_images[0:n_row**2]))
                    batch_recons_images = batch_recons_images.detach().cpu()
                net_encoder.train()
                net_decoder.train()
                save_image(batch_recons_images.data,
                           save_AE_images_in_train_folder +
                           '/{}.png'.format(gen_iterations),
                           nrow=n_row,
                           normalize=True)

            if gen_iterations % 20 == 0:
                print(
                    "AE+lambda{}: [step {}] [epoch {}/{}] [train loss {}] [Time {}]"
                    .format(lambda_sparsity, gen_iterations, epoch + 1, epochs,
                            train_loss / (batch_idx + 1),
                            timeit.default_timer() - start_time))
        # end for batch_idx

        if (epoch + 1) % 50 == 0:
            save_file = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(
                epoch + 1, lambda_sparsity)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            ckpt = {
                'gen_iterations': gen_iterations,
                'net_encoder_state_dict': net_encoder.state_dict(),
                'net_decoder_state_dict': net_decoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': torch.get_rng_state()
            }
            # Training runs on CUDA, so persist its generator as well for an
            # exact resume.
            if torch.cuda.is_available():
                ckpt['cuda_rng_state'] = torch.cuda.get_rng_state()
            torch.save(ckpt, save_file)
    # end for epoch

    return net_encoder, net_decoder
Beispiel #16
0
def train_model(model, datasets, criterion, optimizer):
    """Train ``model`` for ``args.num_batches`` optimizer steps over ``datasets``.

    Batches are assembled by ``get_training_batch`` from a running buffer fed
    dataset by dataset. Whenever it reports ``done``, one optimizer step and
    one ``StepLR`` step are taken. Every ``kSaveModel`` iterations a checkpoint
    (model, optimizer, scheduler, running batch, numpy/torch RNG states and the
    current dataset index) is written to ``args.save_directory``. If
    ``args.resume`` names such a checkpoint, training continues from it.

    Relies on module-level ``args``, ``writer``, ``batchSize``, ``input_size``,
    ``device``, ``kSaveModel``, ``enable_tensorboard``, ``get_training_batch``
    and ``save_checkpoint``.

    Returns:
        The trained ``model``.
    """
    global args, writer
    since = time.time()
    curr_loss = 0
    lr = args.learning_rate
    start_itr = 0
    # Dataset index stored in the resumed checkpoint. It is consumed on the
    # first pass of the outer loop only; None when not resuming.
    resume_dataset_indx = None
    num_running_batch = 0
    running_batch = {
        'previmg': torch.Tensor(batchSize, 3, input_size, input_size),
        'currimg': torch.Tensor(batchSize, 3, input_size, input_size),
        'previmg_x2': torch.Tensor(batchSize, 3, input_size * 2,
                                   input_size * 2),
        'currimg_x2': torch.Tensor(batchSize, 3, input_size * 2,
                                   input_size * 2),
        'currbb': torch.Tensor(batchSize, 4)
    }
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.lr_decay_step,
                                          gamma=args.gamma)

    # resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_itr = checkpoint['itr']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            num_running_batch = checkpoint['num_running_batch']
            running_batch = checkpoint['running_batch']
            lr = checkpoint['lr']
            np.random.set_state(checkpoint['np_rand_state'])
            torch.set_rng_state(checkpoint['torch_rand_state'])
            # Read the dataset index now instead of re-loading the whole
            # checkpoint a second time inside the training loop.
            resume_dataset_indx = checkpoint.get('dataset_indx', 0)
            print("=> loaded checkpoint '{}' (iteration {})".format(
                args.resume, checkpoint['itr']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if not os.path.isdir(args.save_directory):
        os.makedirs(args.save_directory)

    itr = start_itr
    st = time.time()
    while itr < args.num_batches:

        model.train()
        if resume_dataset_indx is not None:
            i = resume_dataset_indx
            resume_dataset_indx = None
        else:
            i = 0

        # train on datasets
        # usually ALOV and ImageNet
        while i < len(datasets):
            dataset = datasets[i]
            i = i + 1
            (running_batch, train_batch, done,
             num_running_batch) = get_training_batch(num_running_batch,
                                                     running_batch, dataset)
            if done:
                # load sample
                x1 = train_batch['previmg'].to(device)
                x2 = train_batch['currimg'].to(device)
                x1_x2 = train_batch['previmg_x2'].to(device)
                x2_x2 = train_batch['currimg_x2'].to(device)
                y = train_batch['currbb'].requires_grad_(False).to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                output = model(x1, x2, x1_x2, x2_x2)
                loss = criterion(output, y)

                # backward + optimize
                loss.backward()
                optimizer.step()
                # The LR scheduler must step after the optimizer; stepping it
                # first skips the initial learning-rate value.
                scheduler.step()

                # statistics
                curr_loss = loss.item()
                end = time.time()
                itr = itr + 1
                print('[training] step = %d/%d, loss = %f, time = %f' %
                      (itr, args.num_batches, curr_loss, end - st))
                sys.stdout.flush()
                del train_batch
                st = time.time()

                if enable_tensorboard:
                    writer.add_scalar('train/batch_loss', curr_loss, itr)

                if itr > 0 and itr % kSaveModel == 0:
                    path = os.path.join(
                        args.save_directory, 'model_itr_' + str(itr) +
                        '_loss_' + str(round(curr_loss, 3)) + '.pth.tar')
                    save_checkpoint(
                        {
                            'itr': itr,
                            'np_rand_state': np.random.get_state(),
                            'torch_rand_state': torch.get_rng_state(),
                            'l1_loss': curr_loss,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'num_running_batch': num_running_batch,
                            'running_batch': running_batch,
                            'lr': lr,
                            'dataset_indx': i
                        }, path)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    if enable_tensorboard:
        writer.export_scalars_to_json("./all_scalars.json")
        writer.close()
    return model
Beispiel #17
0
    def _n_features(self) -> int:
        """
        Calculates the number of extracted features going into the the classifier part.
        :return: Number of features.
        """
        def conv_size(h_in, w_in):
            if "kernel_size" in self.conv_params.keys():
                kernel_size = (self.conv_params['kernel_size'],
                               self.conv_params['kernel_size']) if isinstance(
                                   self.conv_params['kernel_size'],
                                   int) else self.conv_params['kernel_size']
            else:
                kernel_size = (self.kernel_size,
                               self.kernel_size) if isinstance(
                                   self.kernel_size, int) else self.kernel_size

            if "padding" in self.conv_params.keys():
                padding = (self.conv_params['padding'],
                           self.conv_params['padding']) if isinstance(
                               self.conv_params['padding'],
                               int) else self.conv_params['padding']
            else:
                padding = (0, 0)

            if "dilation" in self.conv_params.keys():
                dilation = (self.conv_params['dilation'],
                            self.conv_params['dilation']) if isinstance(
                                self.conv_params['dilation'],
                                int) else self.conv_params['dilation']
            else:
                dilation = (1, 1)

            if "stride" in self.conv_params.keys():
                stride = (self.conv_params['stride'],
                          self.conv_params['stride']) if isinstance(
                              self.conv_params['stride'],
                              int) else self.conv_params['stride']
            else:
                stride = (1, 1)

            # print("========convo===========")
            # print("kernel = " + str(kernel_size))
            # print("padding = " + str(padding))
            # print("dilation = " + str(dilation))
            # print("stride = " + str(stride))

            h_out = floor(((h_in + 2 * padding[0] - dilation[0] *
                            (kernel_size[0] - 1) - 1) / stride[0]) + 1)
            w_out = floor(((w_in + 2 * padding[1] - dilation[1] *
                            (kernel_size[1] - 1) - 1) / stride[1]) + 1)
            # print(h_out , w_out)
            return h_out, w_out

        def pool_size(h_in, w_in):

            if "kernel_size" in self.pooling_params.keys():
                kernel_size = (
                    self.pooling_params['kernel_size'],
                    self.pooling_params['kernel_size']) if isinstance(
                        self.pooling_params['kernel_size'],
                        int) else self.pooling_params['kernel_size']
            # else:
            #     kernel_size = (self.kernel_size, self.kernel_size) if isinstance(
            #         self.kernel_size, int) else self.kernel_size

            if "padding" in self.pooling_params.keys():
                padding = (self.pooling_params['padding'],
                           self.pooling_params['padding']) if isinstance(
                               self.pooling_params['padding'],
                               int) else self.pooling_params['padding']
            else:
                padding = (0, 0)

            if "dilation" in self.pooling_params.keys():
                dilation = (self.pooling_params['dilation'],
                            self.pooling_params['dilation']) if isinstance(
                                self.pooling_params['dilation'],
                                int) else self.pooling_params['dilation']
            else:
                dilation = (1, 1)

            if "stride" in self.pooling_params.keys():
                stride = (self.pooling_params['stride'],
                          self.pooling_params['stride']) if isinstance(
                              self.pooling_params['stride'],
                              int) else self.pooling_params['stride']
            else:
                stride = kernel_size

            # print("========pooling===========")
            # print("kernel = " + str(kernel_size))
            # print("padding = " + str(padding))
            # print("dilation = " + str(dilation))
            # print("stride = " + str(stride))

            h_out = floor(((h_in + 2 * padding[0] - dilation[0] *
                            (kernel_size[0] - 1) - 1) / stride[0]) + 1)
            w_out = floor(((w_in + 2 * padding[1] - dilation[1] *
                            (kernel_size[1] - 1) - 1) / stride[1]) + 1)
            # print(h_out , w_out)
            return h_out, w_out

        # Make sure to not mess up the random state.
        rng_state = torch.get_rng_state()
        try:
            # ====== YOUR CODE: ======
            _, in_h, in_w, = tuple(self.in_size)

            from math import floor

            for channel in range(len(self.channels)):
                # print(channel)
                in_h, in_w = conv_size(in_h, in_w)
                if (channel + 1) % self.pool_every == 0:
                    # if channel > 0 and channel % self.pool_every == 0:
                    in_h, in_w = pool_size(in_h, in_w)

            return in_h * in_w * self.channels[-1]

            # return self.channels[-1] * ceil((in_h * in_w) / ((self.pooling_params['kernel_size'] * self.pooling_params['kernel_size']) ** (len(self.channels) // self.pool_every)))
            # return self.channels[-1] * self.conv_params['kernel_size'] * self.conv_params['kernel_size']

            # raise NotImplementedError()
            # ========================
        finally:
            torch.set_rng_state(rng_state)
def _randomize_labels(labels, seed):
    rng_state = torch.get_rng_state()
    torch.manual_seed(seed)
    labels = list(torch.IntTensor(labels)[torch.randperm(len(labels))])
    torch.set_rng_state(rng_state)
    return labels
Beispiel #19
0
def _checkpoint_without_reentrant(function,
                                  preserve_rng_state=True,
                                  *args,
                                  **kwargs):
    """Checkpointing without re-entrant autograd.

    Runs ``function(*args, **kwargs)`` once under saved-tensor hooks. Instead
    of the tensors autograd would normally save, the hooks store placeholder
    ``Holder`` objects. When backward unpacks a placeholder while ``storage`` is
    empty, ``function`` is run again to refill ``storage``. The rerun uses
    ``torch.enable_grad`` and the autocast settings captured at call time.
    The recomputed tensors are then returned to autograd.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``True``, stash the CPU RNG
            state before the forward pass. When CUDA is already initialized,
            also stash the CUDA RNG states of the devices reported by
            ``get_device_states(*args)``. These states are reinstated for the
            recomputation. Pass ``False`` to skip stashing and restoring.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.

    Returns:
        The output of the first (non-recomputed) call to ``function``.

    Raises:
        RuntimeError: if CUDA is initialized inside ``function`` while
            ``preserve_rng_state`` is set. Also raised during backward if a
            saved tensor is unpacked again with no recomputation in between.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved object is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    # Weak refs to every Holder, in the order pack() created them.
    weak_holder_list = []

    def pack(x):
        # The real tensor x is dropped; autograd keeps only an empty Holder.
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of undeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        # storage is empty before the first unpack (and again once all Holders
        # die), so re-run the forward to fill it.
        if len(storage) == 0:

            def inner_pack(inner):
                # Matches the N-th tensor saved during recomputation to the N-th
                # Holder from the original forward (assumes the recomputation
                # saves tensors in the same order).
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter -
                                         1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError(
                    "You are calling backwards on a tensor that is never exposed. Please open an issue."
                )

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward.  Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices,
                                       enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                # Recompute under the same grad/autocast settings as the
                # original forward; the output itself is discarded.
                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
 def setUp(self):
     """Seed torch with 0 for the test unless ``UNLOCK_SEED`` opts out.

     Seeding happens when ``UNLOCK_SEED`` is unset or equals "false"
     (case-insensitive). In that case the previous CPU RNG state is saved on
     ``self.rng_state`` first.
     """
     unlock_flag = os.getenv("UNLOCK_SEED")
     seed_is_locked = unlock_flag is None or unlock_flag.lower() == "false"
     if not seed_is_locked:
         return
     self.rng_state = torch.get_rng_state()
     torch.manual_seed(0)
def evaluate(args,
             model,
             tokenizer,
             prefix="",
             test=False,
             mc=True,
             generation=False):
    """Evaluate ``model`` on the dev split (or the test split when ``test``).

    For every batch, the model's outputs are used as per-position values. They
    are presumably token losses (TODO confirm), since they are summed and
    exponentiated into perplexity. They are accumulated into ``eval_ppl``
    (corpus-level) and ``eval_avg_ppl`` (mean of per-sequence perplexities).
    When ``mc`` is set, each choice is scored by the negated sum of its outputs
    (also length-normalised) and multiple-choice accuracy is computed; the
    softmaxed scores are saved to ``preds.npy``. When ``generation`` is set,
    sampled questions and distractors are written to ``dev_generated_ans.tsv``.
    The global CPU RNG state is restored afterwards, even on error.

    Returns:
        dict: ``eval_ppl`` and ``eval_avg_ppl``, plus ``eval_acc`` and
        ``eval_acc_normalized`` when ``mc`` is set.
    """

    eval_task_names = (args.task_name, )
    eval_outputs_dirs = (args.output_dir, )

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args,
                                               eval_task,
                                               tokenizer,
                                               evaluate=not test,
                                               test=test)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(
            1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(
            eval_dataset) if args.local_rank == -1 else DistributedSampler(
                eval_dataset)
        eval_dataloader = DataLoader(eval_dataset,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_avg_ppl = 0.0
        eval_ppl = 0.0
        # Tiny non-zero denominators avoid division by zero on an empty split.
        nb_eval_steps = 1e-8
        nb_eval_time_steps = 1e-8
        preds = None
        preds_normalized = None
        out_label_ids = None
        output_file = None
        if generation:
            output_file = open(
                os.path.join(eval_output_dir, "dev_generated_ans.tsv"), 'w')
        # Sampled generation consumes the global RNG; snapshot it so that
        # evaluation leaves the caller's random stream unchanged.
        random_state = torch.get_rng_state()
        try:
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                model.eval()
                batch = tuple(t.to(args.device) for t in batch)

                with torch.no_grad():
                    # LM inputs: every entry along dim 1 except index 0,
                    # flattened to (-1, seq_len).
                    inputs = {
                        'input_ids':
                        batch[0][:, 1:, :].contiguous().view(-1, batch[0].size(2)),
                        'output_ids':
                        batch[1][:, 1:, :].contiguous().view(-1, batch[1].size(2)),
                        'attention_mask':
                        batch[2][:, 1:, :].contiguous().view(-1, batch[2].size(2)),
                        'token_type_ids':
                        batch[3] if args.model_type in ['xlnet'] else
                        None,  # XLM don't use segment_ids
                    }

                    # Multiple-choice inputs: all entries along dim 1.
                    mc_inputs = {
                        'input_ids':
                        batch[0].view(-1, batch[0].size(2)),
                        'output_ids':
                        batch[5].view(-1, batch[5].size(2)),
                        'attention_mask':
                        batch[2].view(-1, batch[2].size(2)),
                        'token_type_ids':
                        batch[3] if args.model_type in ['xlnet'] else
                        None,  # XLM don't use segment_ids
                    }
                    # One forward pass for the LM inputs is enough; its result
                    # is used for both perplexity and generation bookkeeping.
                    outputs = model(**inputs)

                    if mc:
                        mc_outputs = model(**mc_inputs)
                        # Score each choice by the negated sum of its outputs
                        # (higher is better), grouped per example.
                        neg_total = (-1 * mc_outputs.sum(1)).view(
                            -1, batch[0].size(1))
                        mc_score = neg_total.data.cpu().numpy()
                        # Length-normalised variant: divide by the number of
                        # non-zero positions.
                        mc_score_normalized = mc_outputs.sum(1) / (
                            (mc_outputs != 0).float().sum(1))
                        mc_score_normalized = (mc_score_normalized * -1).view(
                            -1, batch[0].size(1)).data.cpu().numpy()
                        batch_labels = batch[4].detach().cpu().numpy()

                        if preds is None:
                            preds = mc_score
                            preds_normalized = mc_score_normalized
                            out_label_ids = batch_labels
                        else:
                            preds = np.append(preds, mc_score, axis=0)
                            preds_normalized = np.append(preds_normalized,
                                                         mc_score_normalized,
                                                         axis=0)
                            out_label_ids = np.append(out_label_ids,
                                                      batch_labels,
                                                      axis=0)

                    batch_total_ppl = torch.exp(
                        outputs.sum(1) / ((outputs != 0).float().sum(1))).sum()
                    batch_total_loss = outputs.sum()

                    eval_avg_ppl += batch_total_ppl.item()
                    eval_ppl += batch_total_loss.item()
                    if generation:
                        distractor_size = 3

                        sep_id = tokenizer.eos_token_id
                        for i in range(0, inputs['input_ids'].size(0),
                                       distractor_size):
                            tmp_output_ids = inputs['output_ids'][i, :]

                            # Positions before the first non -1 output id
                            # (plus one) are used as the generation prompt.
                            question_length = (tmp_output_ids !=
                                               -1).nonzero().min() + 1
                            distractors = set({})
                            for _ in range(distractor_size):
                                question, distractor = model.generate(
                                    inputs['input_ids'][i:i + 1, :question_length],
                                    30,
                                    sample=True,
                                    tmp=1.0,
                                    top_p=1.0,
                                    label=inputs['input_ids'][i:i +
                                                              1, question_length:])
                                # Resample until the distractor is new.
                                while distractor in distractors:
                                    question, distractor = model.generate(
                                        inputs['input_ids'][i:i +
                                                            1, :question_length],
                                        30,
                                        sample=True,
                                        tmp=1.0,
                                        top_p=1.0,
                                        label=inputs['input_ids']
                                        [i:i + 1, question_length:])
                                distractors.add(distractor)

                            res = [question] + list(distractors)
                            output_file.write("\t".join(res) + "\n")

                nb_eval_steps += outputs.size(0)
                nb_eval_time_steps += ((outputs != 0).float().sum()).item()
        finally:
            torch.set_rng_state(random_state)
            if output_file is not None:
                output_file.close()

        eval_avg_ppl = eval_avg_ppl / nb_eval_steps
        eval_ppl = exp(eval_ppl / nb_eval_time_steps)
        if mc:
            preds = F.softmax(torch.tensor(preds), -1).detach().cpu().numpy()
            np.save(os.path.join(eval_output_dir, "preds.npy"), preds)
            preds = np.argmax(preds, axis=1)
            preds_normalized = np.argmax(preds_normalized, axis=1)
            acc = simple_accuracy(preds, out_label_ids)
            acc_normalized = simple_accuracy(preds_normalized, out_label_ids)

            result = {
                "eval_ppl": eval_ppl,
                "eval_avg_ppl": eval_avg_ppl,
                "eval_acc": acc,
                "eval_acc_normalized": acc_normalized
            }
        else:
            result = {"eval_ppl": eval_ppl, "eval_avg_ppl": eval_avg_ppl}
        results.update(result)
        output_eval_file = os.path.join(
            eval_output_dir,
            "is_test_" + str(test).lower() + "_eval_results.txt")

        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(
                str(prefix) + " is test:" + str(test)))
            writer.write("model           =%s\n" %
                         str(args.model_name_or_path))
            writer.write("total batch size=%d\n" %
                         (args.per_gpu_train_batch_size *
                          args.gradient_accumulation_steps *
                          (torch.distributed.get_world_size()
                           if args.local_rank != -1 else 1)))
            writer.write("train num epochs=%d\n" % args.num_train_epochs)
            writer.write("fp16            =%s\n" % args.fp16)
            writer.write("max seq length  =%d\n" % args.input_max_seq_length)
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
    return results
Beispiel #22
0
    def backward(ctx, *args):
        """Recompute the checkpointed forward pass and backpropagate through it.

        The CPU, CUDA and ``get_cuda_rng_tracker()`` RNG states captured at
        forward time (``ctx.fwd_*``) are installed first. ``ctx.run_function``
        is then re-run with grad enabled on detached copies of the saved inputs,
        which are gathered via ``get_full_inputs`` when activations are
        partitioned. The backward-time RNG states are put back afterwards (also
        on error), and ``torch.autograd.backward`` is called with the incoming
        gradients ``args``.

        Returns:
            ``None`` for the first forward argument (presumably
            ``run_function``), followed by the gradient of each saved input.

        Raises:
            RuntimeError: if invoked through ``torch.autograd.grad`` rather
                than ``.backward()``.
        """
        global timers
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            # Drop the module-level references to the contiguous checkpoint
            # buffers so their memory can be garbage collected once the
            # checkpoints have been used; tensors stored by save_for_backward
            # stay reachable through ctx.
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS

        if PARTITION_ACTIVATIONS:
            inputs = get_full_inputs(ctx.saved_tensors,
                                     device=cuda_device if PA_TO_CPU else None)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.saved_tensors
            detached_inputs = detach_variable(inputs)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were before the forward pass, so the
        # recomputation sees the same random numbers (e.g. dropout masks).
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        try:
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)
        finally:
            # Always put back the states from the start of this function, even
            # if the recomputation raised.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )
        torch.autograd.backward(outputs, args)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        return (None, ) + tuple(inp.grad for inp in detached_inputs)
Beispiel #23
0
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of aux feat files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation aux feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_aux", default=54,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--skip_chn", default=256,
                        type=int, help="number of channels of skip output")
    parser.add_argument("--seg", default=1,
                        type=int, help="segment size")
    parser.add_argument("--dilation_depth", default=3,
                        type=int, help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=2,
                        type=int, help="repeat of dilation depth")
    parser.add_argument("--hid_chn", default=192,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--kernel_size", default=7,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_kernel_size", default=3,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_dilation_size", default=2,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor", default=110,
                        type=int, help="upsampling factor of aux features"
                                       "(if set 0, do not apply)")
    parser.add_argument("--n_fft_facts", default=17,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--string_path", default="/feat_org_lf0",
                        type=str, help="directory to save the model")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--batch_size", default=8800,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=4000,
                        type=int, help="number of training epochs")
    parser.add_argument("--do_prob", default=0,
                        type=float, help="dropout probability")
    parser.add_argument("--lpc", default=0,
                        type=int, help="number of linear predictive coefficients for location estimate")
    parser.add_argument("--aux_conv2d_flag", default=False,
                        type=strtobool, help="flag to use 2d conv of aux")
    parser.add_argument("--wav_conv_flag", default=False,
                        type=strtobool, help="flag to use 1d conv of wav")
    # other setting
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--pretrained", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"]     = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"]  = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(level=logging.WARN,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    torch.backends.cudnn.benchmark = True #faster
    #torch.backends.cudnn.deterministic = True #reproducibility_slower
    #torch.backends.cudnn.benchmark = False #reproducibility_slower

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = CSWNV(
        n_aux=args.n_aux,
        skip_chn=args.skip_chn,
        hid_chn=args.hid_chn,
        dilation_depth=args.dilation_depth,
        dilation_repeat=args.dilation_repeat,
        kernel_size=args.kernel_size,
        aux_kernel_size=args.aux_kernel_size,
        aux_dilation_size=args.aux_dilation_size,
        do_prob=args.do_prob,
        seg=args.seg,
        lpc=args.lpc,
        aux_conv2d_flag=args.aux_conv2d_flag,
        wav_conv_flag=args.wav_conv_flag,
        upsampling_factor=args.upsampling_factor)
    logging.info(model)
    criterion_lsd = LSDloss()
    criterion_laplace = LaplaceLoss()

    # define transforms
    string_path_name = args.string_path.split('feat_')[1]
    logging.info(string_path_name)
    scaler = StandardScaler()
    if check_hdf5(args.stats, "/mean_"+string_path_name):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+string_path_name)
    elif check_hdf5(args.stats, "/mean_"+args.string_path):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+args.string_path)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+args.string_path)
    else:
        scaler.mean_ = read_hdf5(args.stats, "/mean_feat_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_feat_"+string_path_name)
    mean_src = torch.FloatTensor(scaler.mean_)
    std_src = torch.FloatTensor(scaler.scale_)

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion_lsd.cuda()
        criterion_laplace.cuda()
        mean_src = mean_src.cuda()
        std_src = std_src.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model.train()
    model.apply(initialize)
    model.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/std_src.data),2))
    model.scale_in.bias = torch.nn.Parameter(-(mean_src.data/std_src.data))

    for param in model.parameters():
        param.requires_grad = True
    for param in model.scale_in.parameters():
        param.requires_grad = False

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters: %.3f million' % parameters)

    module_list = list(model.conv_aux.parameters())
    module_list += list(model.upsampling.parameters())
    if model.aux_conv2d_flag and model.seg > 1:
        module_list += list(model.aux_conv2d.parameters())
    if model.wav_conv_flag:
        module_list += list(model.wav_conv.parameters())
    module_list += list(model.causal.parameters()) + list(model.in_x.parameters())
    module_list += list(model.dil_h.parameters()) + list(model.out_skip.parameters())
    module_list += list(model.out_1.parameters()) + list(model.out_2.parameters())
    optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None:
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint["model"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    if args.pretrained is None:
        generator = train_generator(
            wav_list, feat_list,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=True,
            upsampling_factor=args.upsampling_factor)
    else:
        generator = train_generator(
            wav_list, feat_list,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=True,
            upsampling_factor=args.upsampling_factor)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames_eval = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames_eval]
        feat_list_eval = [args.feats_eval + "/" + filename.replace(".wav", ".h5") \
                            for filename in filenames_eval]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    logging.info("number of evaluation data = %d." % len(wav_list_eval))
    assert len(wav_list_eval) == len(feat_list_eval)
    if args.pretrained is None:
        generator_eval = train_generator(
            wav_list_eval, feat_list_eval,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=False,
            upsampling_factor=args.upsampling_factor)
    else:
        generator_eval = train_generator(
            wav_list_eval, feat_list_eval,
            model.receptive_field,
            string_path=args.string_path,
            seg=model.seg,
            batch_size=args.batch_size,
            training=False,
            upsampling_factor=args.upsampling_factor)

    # train
    loss_laplace = []
    loss_err = []
    loss_lsd = []
    fft_facts = []
    init_fft = 64
    hann_win = [None]*args.n_fft_facts
    if args.n_fft_facts == 5:
        fft_facts = [128, 256, 512, 1024, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    elif args.n_fft_facts == 9:
        fft_facts = [128, 192, 256, 384, 512, 768, 1024, 1536, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    elif args.n_fft_facts == 17:
        fft_facts = [128, 160, 192, 224, 256, 320, 384, 448, 512, 640, 768, 896, 1024, 1280, 1536, 1792, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    else:
        for i in range(args.n_fft_facts):
            if i % 2 == 0:
                init_fft *= 2
                fft_facts.append(init_fft)
            else:
                fft_facts.append(init_fft+int(init_fft/2))
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    logging.info(fft_facts)
    batch_stft_loss = [None]*args.n_fft_facts
    stft_out = [None]*args.n_fft_facts
    stft_trg = [None]*args.n_fft_facts
    total = 0
    iter_idx = 0
    iter_count = 0
    min_eval_loss_lsd = 99999999.99
    min_eval_loss_laplace = 99999999.99
    min_eval_loss_err = 99999999.99
    min_eval_loss_lsd_std = 99999999.99
    min_eval_loss_laplace_std = 99999999.99
    min_eval_loss_err_std = 99999999.99
    min_idx = -1
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    #args.epoch_count = 5300
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        if c_idx < 0: # summarize epoch
            # save current epoch model
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            save_checkpoint(args.expdir, model, optimizer, numpy_random_state, torch_random_state, epoch_idx+1)
            # report current epoch
            logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "\
                            "(+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, np.mean(loss_laplace), np.std(loss_laplace), np.mean(loss_lsd), \
                    np.std(loss_lsd), np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count))
            logging.info("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\
            "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            loss_lsd = []
            loss_err = []
            loss_laplace = []
            total = 0
            iter_count = 0
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \
                        next(generator_eval)
                    if c_idx < 0:
                        break

                    tf = batch_h.shape[0]
                    ts = batch_x_float.shape[0]

                    batch_h = batch_h[h_ss:]
                    batch_x_ = batch_x_float[x_ss:]
                    if model.lpc > 0:
                        if x_ss+model.lpc_offset >= 0:
                            batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:]
                        else:
                            batch_x_lpc = batch_x_float[x_ss:]
                    if h_bs != -1:
                        batch_h = batch_h[:h_bs]
                        if model.lpc > 0:
                            if x_ss+model.lpc_offset >= 0:
                                batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0)
                            else:
                                batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0), \
                                                        'constant', 0).unsqueeze(0)
                        batch_x = batch_x_[:x_bs-model.seg]
                        batch_x_float = batch_x_[model.seg:x_bs]
                    else:
                        if model.lpc > 0:
                            if x_ss+model.lpc_offset > 0:
                                batch_x_prob = batch_x_lpc.unsqueeze(0)
                            else:
                                batch_x_prob = F.pad(batch_x_lpc, (-(x_ss+model.lpc_offset)), \
                                                    'constant', 0).unsqueeze(0)
                        batch_x = batch_x_[:-model.seg]
                        batch_x_float = batch_x_[model.seg:]
                    batch_h = batch_h.transpose(0,1).unsqueeze(0)
                    batch_x = batch_x.unsqueeze(0).unsqueeze(1)
                    if h_ss > 0:
                        feat_len = batch_x_float[model.receptive_field:].shape[0]
                    else:
                        feat_len = batch_x_float.shape[0]

                    if model.lpc > 0:
                        mus, bs, log_bs, ass = model(batch_h, batch_x)
                        # jump off s samples as in synthesis
                        mus = mus[:,::model.seg,:]
                        bs = bs[:,::model.seg,:]
                        log_bs = log_bs[:,::model.seg,:]
                        ass = ass[:,::model.seg,:].flip(-1)
                        init_mus = mus
                        for j in range(model.seg):
                            tmp_smpls = batch_x_prob[:,j:-(model.seg-j)].unfold(1, model.lpc, model.seg)
                            lpc = torch.sum(ass*tmp_smpls,-1,keepdim=True)
                            if j > 0:
                                mus = torch.cat((mus, lpc+init_mus[:,:,j:j+1]),2)
                            else:
                                mus = lpc+init_mus[:,:,j:j+1]
                        mus = mus.reshape(mus.shape[0],-1)
                        bs = bs.reshape(bs.shape[0],-1)
                        log_bs = log_bs.reshape(log_bs.shape[0],-1)
                    else:
                        mus, bs, log_bs = model(batch_h, batch_x)

                    if h_ss > 0:
                        mus = mus[0,model.receptive_field:]
                        bs = bs[0,model.receptive_field:]
                        log_bs = log_bs[0,model.receptive_field:]
                        batch_x_float = batch_x_float[model.receptive_field:]
                    else:
                        mus = mus[0]
                        bs = bs[0]
                        log_bs = log_bs[0]

                    m_sum = 0
                    batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs)
                    eps = torch.empty(mus.shape).cuda().uniform_(-0.4999,0.5)
                    batch_output = mus-bs*eps.sign()*torch.log1p(-2*eps.abs())
                    batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float))
                    logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float), \
                                                    torch.max(batch_x_float), torch.var(batch_x_float)))
                    logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \
                                                    torch.max(batch_output), torch.var(batch_output)))
                    m = 0
                    for i in range(args.n_fft_facts):
                        if feat_len > int(fft_facts[i]/2):
                            stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i])
                            stft_trg[i] = torch.stft(batch_x_float, fft_facts[i], window=hann_win[i])
                            tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i])
                            if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                                if m > 0:
                                    batch_loss_lsd = torch.cat((batch_loss_lsd, \
                                                                tmp_batch_stft_loss.unsqueeze(0)))
                                else:
                                    batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0)
                                m += 1

                    loss_err.append(batch_loss_err.item())
                    loss_laplace.append(batch_loss_laplace.item())
                    if m > 0:
                        batch_loss_lsd = torch.mean(batch_loss_lsd)
                        loss_lsd.append(batch_loss_lsd.item())
                        logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f "\
                            "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),\
                            os.path.basename(wavfile)), c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, \
                            batch_loss_laplace.item(), batch_loss_lsd.item(), \
                            batch_loss_err.item(), time.time() - start))
                    else:
                        logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f "\
                            "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),\
                            os.path.basename(wavfile)), c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, \
                            batch_loss_laplace.item(), batch_loss_err.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            eval_loss_lsd = np.mean(loss_lsd)
            eval_loss_lsd_std = np.std(loss_lsd)
            eval_loss_err = np.mean(loss_err)
            eval_loss_err_std = np.std(loss_err)
            eval_loss_laplace = np.mean(loss_laplace)
            eval_loss_laplace_std = np.std(loss_laplace)
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "\
                "(+- %.6f) (%.3f min., %.3f sec / batch)" % (epoch_idx + 1, eval_loss_laplace, \
                eval_loss_laplace_std, eval_loss_lsd, eval_loss_lsd_std, eval_loss_err, eval_loss_err_std, \
                total / 60.0, total / iter_count))
            if (eval_loss_laplace+eval_loss_laplace_std+eval_loss_lsd+eval_loss_lsd_std+eval_loss_err\
                +eval_loss_err_std) <= (min_eval_loss_laplace+min_eval_loss_laplace_std+min_eval_loss_lsd\
                    +min_eval_loss_lsd_std+min_eval_loss_err+min_eval_loss_err_std):
                min_eval_loss_lsd = eval_loss_lsd
                min_eval_loss_lsd_std = eval_loss_lsd_std
                min_eval_loss_err = eval_loss_err
                min_eval_loss_err_std = eval_loss_err_std
                min_eval_loss_laplace = eval_loss_laplace
                min_eval_loss_laplace_std = eval_loss_laplace_std
                min_idx = epoch_idx
            logging.info("min_eval_loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f (+- %.6f) min_idx=%d" % (
                min_eval_loss_laplace, min_eval_loss_laplace_std, min_eval_loss_lsd, min_eval_loss_lsd_std, \
                min_eval_loss_err, min_eval_loss_err_std, min_idx+1))
            loss_lsd = []
            loss_laplace = []
            loss_err = []
            total = 0
            iter_count = 0
            epoch_idx += 1
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model.train()
            for param in model.parameters():
                param.requires_grad = True
            for param in model.scale_in.parameters():
                param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx+1))
                logging.info("Training data")
                batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1))

            tf = batch_h.shape[0]
            ts = batch_x_float.shape[0]

            batch_h = batch_h[h_ss:]
            batch_x_ = batch_x_float[x_ss:]
            if model.lpc > 0:
                if x_ss+model.lpc_offset >= 0:
                    batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:]
                else:
                    batch_x_lpc = batch_x_float[x_ss:]
            if h_bs != -1:
                batch_h = batch_h[:h_bs]
                if model.lpc > 0:
                    if x_ss+model.lpc_offset >= 0:
                        batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0)
                    else:
                        batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0), \
                                                'constant', 0).unsqueeze(0)
                batch_x = batch_x_[:x_bs-model.seg]
                batch_x_float = batch_x_[model.seg:x_bs]
            else:
                if model.lpc > 0:
                    if x_ss+model.lpc_offset > 0:
                        batch_x_prob = batch_x_lpc.unsqueeze(0)
                    else:
                        batch_x_prob = F.pad(batch_x_lpc, (-(x_ss+model.lpc_offset)), \
                                                'constant', 0).unsqueeze(0)
                batch_x = batch_x_[:-model.seg]
                batch_x_float = batch_x_[model.seg:]
            batch_h = batch_h.transpose(0,1).unsqueeze(0)
            batch_x = batch_x.unsqueeze(0).unsqueeze(1)
            if h_ss > 0:
                if model.seg > 1:
                    feat_len = batch_x_float[model.receptive_field:-(model.seg-1)].shape[0]
                else:
                    feat_len = batch_x_float[model.receptive_field:].shape[0]
            else:
                if model.seg > 1:
                    feat_len = batch_x_float[:-(model.seg-1)].shape[0]
                else:
                    feat_len = batch_x_float.shape[0]

            if model.lpc > 0:
                mus, bs_noclip, bs, log_bs, ass = model(batch_h, batch_x, do=True, clip=True)
                ass = ass.flip(-1)
                init_mus = mus
                for j in range(model.seg):
                    tmp_smpls = batch_x_prob[:,j:-(model.seg-j)].unfold(1, model.lpc, 1)
                    lpc = torch.sum(ass*tmp_smpls,-1,keepdim=True)
                    if j > 0:
                        mus = torch.cat((mus, lpc+init_mus[:,:,j:j+1]),2)
                    else:
                        mus = lpc+init_mus[:,:,j:j+1]
                if model.seg == 1:
                    mus = mus.reshape(mus.shape[0], -1)
                    bs_noclip = bs_noclip.reshape(mus.shape[0], -1)
                    bs = bs.reshape(mus.shape[0], -1)
                    log_bs = log_bs.reshape(mus.shape[0], -1)
            else:
                mus, bs_noclip, bs, log_bs = model(batch_h, batch_x, do=True, clip=True)

            if h_ss > 0:
                mus = mus[0,model.receptive_field:]
                bs_noclip = bs_noclip[0,model.receptive_field:]
                bs = bs[0,model.receptive_field:]
                log_bs = log_bs[0,model.receptive_field:]
                batch_x_float = batch_x_float[model.receptive_field:]
            else:
                mus = mus[0]
                bs_noclip = bs_noclip[0]
                bs = bs[0]
                log_bs = log_bs[0]

            m_sum = 0
            if model.seg > 1:
                n_sum = 0
                for i in range(model.seg):
                    if i > 0:
                        i_n = i+1
                        mus_i = mus[:,i:i_n].squeeze(-1)
                        bs_noclip_i = bs_noclip[:,i:i_n].squeeze(-1)
                        if i_n < model.seg:
                            batch_x_float_i = batch_x_float[i:-(model.seg-(i_n))]
                        else:
                            batch_x_float_i = batch_x_float[i:]
                        tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,i:i_n].squeeze(-1), \
                                                batch_x_float_i, log_b=log_bs[:,i:i_n].squeeze(-1), log=False)
                        batch_loss_laplace = torch.cat((batch_loss_laplace, \
                                                        tmp_batch_loss_laplace.unsqueeze(0)))
                    else:
                        mus_i = mus[:,:1].squeeze(-1)
                        bs_noclip_i = bs_noclip[:,:1].squeeze(-1)
                        batch_x_float_i = batch_x_float[:-(model.seg-1)]
                        tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,:1].squeeze(-1), \
                                                    batch_x_float_i, log_b=log_bs[:,:1].squeeze(-1))
                        batch_loss_laplace = tmp_batch_loss_laplace.unsqueeze(0)
                    eps = torch.empty(mus_i.shape).cuda().uniform_(-0.4999,0.5)
                    batch_output = mus_i-bs_noclip_i*eps.sign()*torch.log1p(-2*eps.abs())
                    tmp_batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float_i))
                    if i > 0:
                        batch_loss_err = torch.cat((batch_loss_err, tmp_batch_loss_err.unsqueeze(0)))
                    else:
                        batch_loss_err = tmp_batch_loss_err.unsqueeze(0)
                    if i == 0:
                        logging.info("%lf %E %lf %E" % (torch.min(batch_x_float_i), \
                        torch.mean(batch_x_float_i), torch.max(batch_x_float_i), torch.var(batch_x_float_i)))
                        logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \
                        torch.max(batch_output), torch.var(batch_output)))
                    n = 0
                    for i in range(args.n_fft_facts):
                        if feat_len > int(fft_facts[i]/2):
                            stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i])
                            stft_trg[i] = torch.stft(batch_x_float_i, fft_facts[i], window=hann_win[i])
                            tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False)
                            if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                                if n > 0:
                                    tmp_batch_loss_stft_l1 = torch.cat((tmp_batch_loss_stft_l1, \
                                                                        tmp_batch_stft_loss.unsqueeze(0)))
                                else:
                                    tmp_batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0)
                                n += 1
                    if n > 0:
                        if n_sum > 0:
                            batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \
                                                            torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0)))
                        else:
                            batch_loss_stft_l1 = torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0)
                    n_sum += n
                    m = 0
                    for i in range(args.n_fft_facts):
                        if feat_len > int(fft_facts[i]/2):
                            tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i])
                            if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                                if m > 0:
                                    tmp_batch_loss_lsd = torch.cat((tmp_batch_loss_lsd, \
                                                                    tmp_batch_stft_loss.unsqueeze(0)))
                                else:
                                    tmp_batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0)
                                m += 1
                    if m > 0:
                        if m_sum > 0:
                            batch_loss_lsd = torch.cat((batch_loss_lsd, \
                                                        torch.mean(tmp_batch_loss_lsd).unsqueeze(0)))
                        else:
                            batch_loss_lsd = torch.mean(tmp_batch_loss_lsd).unsqueeze(0)
                    m_sum += m
                batch_loss_laplace = torch.mean(batch_loss_laplace)
                batch_loss = batch_loss_laplace
                if n_sum > 0:
                    batch_loss += torch.mean(batch_loss_stft_l1)
                if m_sum > 0:
                    batch_loss_lsd = torch.mean(batch_loss_lsd)
                batch_loss_err = torch.mean(batch_loss_err)
            else:
                batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs)
                batch_loss = batch_loss_laplace
                eps = torch.empty(mus.shape).cuda().uniform_(-0.4999,0.5)
                batch_output = mus-bs_noclip*eps.sign()*torch.log1p(-2*eps.abs())
                batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float))
                logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float), \
                                                torch.max(batch_x_float), torch.var(batch_x_float)))
                logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \
                                                torch.max(batch_output), torch.var(batch_output)))
                n = 0
                for i in range(args.n_fft_facts):
                    if feat_len > int(fft_facts[i]/2):
                        stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i])
                        stft_trg[i] = torch.stft(batch_x_float, fft_facts[i], window=hann_win[i])
                        tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False)
                        if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                            if n > 0:
                                batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \
                                                                tmp_batch_stft_loss.unsqueeze(0)))
                            else:
                                batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0)
                            n += 1
                if n > 0:
                    batch_loss += torch.mean(batch_loss_stft_l1)
                m = 0
                for i in range(args.n_fft_facts):
                    if feat_len > int(fft_facts[i]/2):
                        tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i])
                        if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                            if m > 0:
                                batch_loss_lsd = torch.cat((batch_loss_lsd, tmp_batch_stft_loss.unsqueeze(0)))
                            else:
                                batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0)
                            m += 1
                if m > 0:
                    batch_loss_lsd = torch.mean(batch_loss_lsd)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_err.append(batch_loss_err.item())
            loss_laplace.append(batch_loss_laplace.item())
            if (model.seg > 1 and m_sum > 0) or (model.seg == 1 and m > 0):
                loss_lsd.append(batch_loss_lsd.item())
                logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f (%.3f sec)" % (
                    os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \
                    utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \
                    batch_loss_lsd.item(), batch_loss_err.item(), time.time() - start))
            else:
                logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f (%.3f sec)" % (
                    os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \
                    utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \
                    batch_loss_err.item(), time.time() - start))
            iter_idx += 1
            iter_count += 1
            total += time.time() - start

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Beispiel #24
0
 def __enter__(self):
     """Snapshot the global CPU torch RNG state, then seed torch with ``self.rng_seed``.

     The previous CPU generator state is stored in ``self.old_rng_state``,
     presumably so a matching ``__exit__`` can restore it (not visible
     here -- TODO confirm).

     Returns:
         self, so the object can be bound with ``with ... as ctx``.
     """
     self.old_rng_state = torch.get_rng_state()
     torch.manual_seed(self.rng_seed)
     # Returning self (instead of the implicit None) makes ``as`` usable.
     return self
 def record_rng(self, *args):
     """Record the current RNG state(s) on ``self``.

     Always stores the CPU torch generator state in ``self.cpu_state``.
     Only when ``torch.cuda._initialized`` is true does it also set
     ``self.cuda_in_fwd`` and store ``get_device_states(*args)`` in
     ``self.gpu_devices`` / ``self.gpu_states``.
     """
     self.cpu_state = torch.get_rng_state()
     # ``_initialized`` is a private torch flag; presumably true only once
     # CUDA has been lazily initialised in this process -- TODO confirm.
     if not torch.cuda._initialized:
         return
     self.cuda_in_fwd = True
     # get_device_states is defined elsewhere; presumably yields the CUDA
     # devices of ``args`` and their RNG states -- verify against its source.
     self.gpu_devices, self.gpu_states = get_device_states(*args)
Beispiel #26
0
def random_splits_mask_class(
    dataset,
    num_train_per_class=20,
    num_val_per_class=30,
    num_val=None,
    num_test=None,
    seed=None,
):
    r"""Resample per-class train/val/test masks on ``dataset[0]`` and re-collate.

    For every class, ``num_train_per_class`` nodes are drawn for training and
    the next ``num_val_per_class`` for validation, as suggested in Pitfalls of
    graph neural network evaluation [1]_ for semi-supervised learning.
    Existing ``train_mask``/``val_mask``/``test_mask`` tensors are cleared and
    reused; if they are absent (or not tensors) new all-False masks are
    created on ``data.edge_index.device``.

    References
    ----------
    .. [1] Shchur, O., Mumme, M., Bojchevski, A., & Günnemann, S. (2018).
        Pitfalls of graph neural network evaluation.
        arXiv preprint arXiv:1811.05868.

    Parameters
    ----------
    dataset :
        Dataset whose first graph is re-split. Must support ``dataset[0]``,
        iteration and ``collate``, and have writable ``data``/``slices``
        attributes (presumably a PyTorch Geometric ``InMemoryDataset`` --
        TODO confirm).

    num_train_per_class : int
        the number of samples from every class used for training.

    num_val_per_class : int
        the number of samples from every class used for validation. When
        ``num_val`` is given these per-class nodes are still drawn (and the
        size check below still applies) but are discarded.

    num_val : int, optional
        if given, the validation set is instead ``num_val`` nodes drawn at
        random from all non-training nodes.

    num_test : int, optional
        only used together with ``num_val``: the number of the remaining
        nodes used for testing. If None, all remaining nodes are used.
        Without ``num_val``, every node not in train/val is a test node.

    seed : int, optional
        random seed for splitting the dataset. The global torch (and CUDA)
        RNG state is restored before returning, even if an error is raised.

    Returns
    -------
    dataset
        The same object, with ``data`` and ``slices`` re-collated.

    Raises
    ------
    AssertionError
        If a class does not have more than
        ``num_train_per_class + num_val_per_class`` nodes.
    """
    data = dataset[0]

    # Snapshot the global RNG state so seeding here does not leak into the
    # caller; it is restored in ``finally`` even if the split fails.
    cpu_rng_state = torch.get_rng_state()
    cuda_rng_state = (torch.cuda.get_rng_state()
                      if torch.cuda.is_available() else None)
    try:
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)

        num_classes = data.y.max().cpu().item() + 1
        try:
            data.train_mask.fill_(False)
            data.val_mask.fill_(False)
            data.test_mask.fill_(False)
        except AttributeError:
            # A mask attribute is missing or not a tensor (e.g. None):
            # create fresh all-False masks for all three splits.
            mask_device = data.edge_index.device
            for mask_name in ("train_mask", "val_mask", "test_mask"):
                setattr(data, mask_name,
                        torch.zeros(data.num_nodes,
                                    dtype=torch.bool,
                                    device=mask_device))

        for c_i in range(num_classes):
            idx = (data.y == c_i).nonzero().view(-1)
            assert num_train_per_class + num_val_per_class < idx.size(0), (
                "the total number of samples from every class used for training and validation is larger than the total samples in class "
                + str(c_i))
            shuffled = idx[torch.randperm(idx.size(0))]
            data.train_mask[shuffled[:num_train_per_class]] = True
            data.val_mask[shuffled[num_train_per_class:num_train_per_class +
                                   num_val_per_class]] = True

        if num_val is not None:
            # ``num_val`` replaces the per-class validation draw. Clear it so
            # validation has exactly ``num_val`` nodes and no node can be in
            # both the validation and the test mask.
            data.val_mask.fill_(False)
            remaining = (~data.train_mask).nonzero().view(-1)
            remaining = remaining[torch.randperm(remaining.size(0))]
            data.val_mask[remaining[:num_val]] = True
            if num_test is not None:
                data.test_mask[remaining[num_val:num_val + num_test]] = True
            else:
                data.test_mask[remaining[num_val:]] = True
        else:
            remaining = (~(data.train_mask | data.val_mask)).nonzero().view(-1)
            data.test_mask[remaining] = True
    finally:
        torch.set_rng_state(cpu_rng_state)
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)

    dataset.data, dataset.slices = dataset.collate(list(dataset))
    return dataset
##### load data #####

# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only point `processed` at trusted data.
with open(processed + dataset_name + 'networkx_graphs.pkl', 'rb') as graph_file:
    graph_list = pickle.load(graph_file)
data_len = len(graph_list)

##### training parameters #####
# Fixed list of epoch indices, truncated to those below max_epoch
# (presumably the epochs at which results are sampled -- TODO confirm).
epoch_samples = [k for k in [0, 24, 49, 75, 99, 124, 149, 174, 199] if k < max_epoch]
# Per-dataset batch size; dataset_name must be one of these keys.
bs = {'DHFR/': 11, 'MUTAG/': 10, 'COX2/': 9, 'IMDB-BINARY/': 18, 'NCI1/': 20, 'IMDB-MULTI/': 27 }

batch_size = bs[dataset_name]
# One tenth of the graphs (rounded down) is reserved as the test split.
test_size = data_len // 10
# Number of training batches, counting a final partial batch.
train_batches = np.ceil((data_len-test_size)/batch_size).astype(int)
print('number of batches = ', train_batches, ' batch size = ', batch_size)

# Seed once and keep the resulting state so that every training run can be
# restarted from the same initial conditions.
torch.manual_seed(0)
rng_state = torch.get_rng_state()

### eigenvalue path signature ####
if xtra_feat:
    pslevel = 4
    # Prepare iisignature for 2-dimensional paths truncated at level pslevel.
    sig_prep = iisignature.prepare(2, pslevel)
    siglength = iisignature.logsiglength(2, pslevel)
    xtra_feat_length = siglength
    # With xxtra, siglength grows by 4 (xtra_feat_length is left unchanged).
    if xxtra:
        siglength += 4
else:
    xtra_feat_length = 0

print(xtra_feat_length)

Beispiel #28
0
    def train(self, net, optimizer, resume=False, scheduler=None):
        """Run the training loop, optionally resuming from ``state.dict``.

        Each epoch calls ``callback_train_new_epoch``, ``train_epoch`` and
        ``test``. Every ``save_frequency`` epochs it writes the following:

        * the full training state (epoch, best errors, network, optimizer,
          CPU and, if available, CUDA RNG state) to
          ``exp_out_root / 'state.dict'``;
        * a copy of that state to ``state_set<name>_best.dict`` for each
          test set whose summed error improved;
        * the network weights to ``get_net_path(epoch)``.

        Args:
            net: network to train; moved to ``self.train_device``.
            optimizer: optimizer for ``net``.
            resume: if True and ``state.dict`` exists, restore the network,
                optimizer (best effort), best errors and RNG states from it
                and continue with the following epoch.
            scheduler: optional LR scheduler, stepped once per epoch.
        """
        logging.info('=' * 80)
        logging.info('Start training')
        self.log_datetime()
        logging.info('=' * 80)

        train_set = self.get_train_set()
        test_sets = self.get_test_sets()

        net = net.to(self.train_device)

        epoch = 0
        min_err = {ts.name: 1e9 for ts in test_sets}

        state_path = self.exp_out_root / 'state.dict'
        if resume and state_path.exists():
            logging.info('=' * 80)
            logging.info(f'Loading state from {state_path}')
            logging.info('=' * 80)
            state = torch.load(str(state_path))
            epoch = state['epoch'] + 1
            if 'min_err' in state:
                min_err = state['min_err']

            # Merge into the current weights so keys absent from the
            # checkpoint keep their present values.
            curr_state = net.state_dict()
            curr_state.update(state['state_dict'])
            net.load_state_dict(curr_state)

            try:
                optimizer.load_state_dict(state['optimizer'])
            except Exception:
                # Best effort (e.g. parameter groups changed); a bare except
                # would also swallow KeyboardInterrupt/SystemExit.
                logging.info('Warning: cannot load optimizer from state_dict')
            if 'cpu_rng_state' in state:
                torch.set_rng_state(state['cpu_rng_state'])
            if 'gpu_rng_state' in state:
                if torch.cuda.is_available():
                    torch.cuda.set_rng_state(state['gpu_rng_state'])
                else:
                    logging.info('Warning: CUDA unavailable, gpu_rng_state not restored')

        for epoch in range(epoch, self.epochs):
            self.callback_train_new_epoch(epoch, net, optimizer)

            # train epoch
            self.train_epoch(epoch, net, optimizer, train_set)

            # test epoch
            errs = self.test(epoch, net, test_sets)

            if (epoch + 1) % self.save_frequency == 0:
                net = net.to(self.train_device)

                # Update the best errors *before* snapshotting, so the saved
                # state (and therefore a resumed run) sees the current bests.
                improved_sets = []
                for test_set_name in errs:
                    err = sum(errs[test_set_name])
                    if err < min_err[test_set_name]:
                        min_err[test_set_name] = err
                        improved_sets.append(test_set_name)

                # store state
                state_dict = {
                    'epoch': epoch,
                    'min_err': min_err,
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'cpu_rng_state': torch.get_rng_state(),
                }
                # torch.cuda.get_rng_state() fails on machines without CUDA.
                if torch.cuda.is_available():
                    state_dict['gpu_rng_state'] = torch.cuda.get_rng_state()
                logging.info(f'save state to {state_path}')
                torch.save(state_dict, str(state_path))

                # A separate name keeps state_path pointing at state.dict.
                for test_set_name in improved_sets:
                    best_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
                    logging.info(f'save state to {best_path}')
                    torch.save(state_dict, str(best_path))

                # store network
                net_path = self.get_net_path(epoch)
                logging.info(f'save network to {net_path}')
                torch.save(net.state_dict(), str(net_path))

            if scheduler is not None:
                scheduler.step()

        logging.info('=' * 80)
        logging.info('Finished training')
        self.log_datetime()
        logging.info('=' * 80)
Beispiel #29
0
def run(opts):
    """Train, or only evaluate, an attention / pointer-network model.

    Builds the problem, model, baseline, optimizer and LR scheduler from
    ``opts``. The learning rates and LR decay are read back through
    ``wandb.config`` (initialised from ``config_defaults``), so that a wandb
    sweep can override them. The function optionally restores model,
    optimizer, baseline and RNG state from ``opts.load_path`` or
    ``opts.resume``.

    With ``opts.eval_only`` it only validates. Otherwise it trains for
    ``opts.n_epochs`` epochs. Once the elapsed wall-clock time exceeds an
    entry of ``opts.save_hrs`` (hours), it pickles the per-epoch results of
    ``train_epoch`` and saves checkpoints under ``../models/att``.

    Args:
        opts: parsed run options (argparse-style namespace). ``opts.save_hrs``
            is sorted and consumed in place; ``opts.device`` is assigned, and
            ``opts.epoch_start`` too when resuming.
    """
    # start time
    start_time = time()
    train_run = []
    opts.save_hrs.sort()
    run_name = opts.run_name

    # Pretty print the run args
    pp.pprint(vars(opts))

    # Set the random seed
    torch.manual_seed(opts.seed)

    # Optionally configure tensorboard
    tb_logger = None
    if not opts.no_tensorboard:
        tb_logger = TbLogger(
            os.path.join(opts.log_dir, "{}_{}".format(opts.problem,
                                                      opts.graph_size),
                         opts.run_name))

    # Raises if the directory already exists (no exist_ok) -- presumably
    # deliberate, to avoid overwriting a previous run.
    os.makedirs(opts.save_dir)
    # Save arguments so exact configuration can always be found
    with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
        json.dump(vars(opts), f, indent=True)

    # Set the device
    opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu")

    # Figure out what's the problem
    problem = load_problem(opts.problem)

    # Load data from load_path
    load_data = {}
    assert opts.load_path is None or opts.resume is None, "Only one of load path and resume can be given"
    load_path = opts.load_path if opts.load_path is not None else opts.resume
    if load_path is not None:
        print('  [*] Loading data from {}'.format(load_path))
        load_data = torch_load_cpu(load_path)

    # hyperparameter search
    # default (user specified) config
    config_defaults = {
        'batch_size': opts.batch_size,
        'lr_model': opts.lr_model,
        'lr_critic': opts.lr_critic,
        'lr_decay': opts.lr_decay,
    }

    # determine the parameter space
    # NOTE(review): kept for reference only (an unused string). If it is
    # re-enabled, the 'lr_decay' entry needs a 'values' key.
    """sweep_config = {
        'parameters': {
            'batch_size': {
                'values': [256, 128, 64, 32]
            },
            'lr_model': {
                'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5]
            },
            'lr_critic': {
                'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5]
            },
            'lr_decay': {
                'lr_decay': [0.9, 0.95, 1.0, 1.05, 1.1, 1.15]
            },
        }
    }"""
    # initialize the sweep
    # sweep_id = wandb.sweep(sweep_config, project="Pytorch-sweeps")

    # Initialize a new wandb run
    wandb.init(config=config_defaults)

    # Config holds the (possibly sweep-overridden) hyperparameters. Code that
    # should honour the sweep must read config.<name> instead of
    # opts.<name>, including helpers in other files (pass ``config`` along).
    config = wandb.config

    # Initialize model
    model_class = {
        'attention': AttentionModel,
        'pointer': PointerNetwork
    }.get(opts.model, None)
    # Report the requested name; model_class is always None at this point.
    assert model_class is not None, "Unknown model: {}".format(opts.model)
    model = model_class(opts.embedding_dim,
                        opts.hidden_dim,
                        problem,
                        n_encode_layers=opts.n_encode_layers,
                        mask_inner=True,
                        mask_logits=True,
                        normalization=opts.normalization,
                        tanh_clipping=opts.tanh_clipping,
                        checkpoint_encoder=opts.checkpoint_encoder,
                        shrink_size=opts.shrink_size).to(opts.device)

    if opts.use_cuda and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # Overwrite model parameters by parameters to load
    model_ = get_inner_model(model)
    model_.load_state_dict({
        **model_.state_dict(),
        **load_data.get('model', {})
    })

    # Initialize baseline
    if opts.baseline == 'exponential':
        baseline = ExponentialBaseline(opts.exp_beta)
    elif opts.baseline == 'critic' or opts.baseline == 'critic_lstm':
        assert problem.NAME == 'tsp', "Critic only supported for TSP"
        baseline = CriticBaseline(
            (CriticNetworkLSTM(2, opts.embedding_dim, opts.hidden_dim,
                               opts.n_encode_layers, opts.tanh_clipping)
             if opts.baseline == 'critic_lstm' else CriticNetwork(
                 2, opts.embedding_dim, opts.hidden_dim, opts.n_encode_layers,
                 opts.normalization)).to(opts.device))
    elif opts.baseline == 'rollout':
        baseline = RolloutBaseline(model, problem, opts)
    else:
        assert opts.baseline is None, "Unknown baseline: {}".format(
            opts.baseline)
        baseline = NoBaseline()

    if opts.bl_warmup_epochs > 0:
        baseline = WarmupBaseline(baseline,
                                  opts.bl_warmup_epochs,
                                  warmup_exp_beta=opts.exp_beta)

    # Load baseline from data, make sure script is called with same type of baseline
    if 'baseline' in load_data:
        baseline.load_state_dict(load_data['baseline'])

    # Initialize optimizer
    optimizer = optim.Adam([{
        'params': model.parameters(),
        'lr': config.lr_model
    }] + ([{
        'params': baseline.get_learnable_parameters(),
        'lr': config.lr_critic
    }] if len(baseline.get_learnable_parameters()) > 0 else []))

    # Load optimizer state
    if 'optimizer' in load_data:
        optimizer.load_state_dict(load_data['optimizer'])
        # Move loaded optimizer tensors onto the training device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(opts.device)

    # Initialize learning rate scheduler, decay by lr_decay once per epoch!
    lr_scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lambda epoch: config.lr_decay**epoch)

    # Start the actual training loop
    val_dataset = problem.make_dataset(size=opts.graph_size,
                                       num_samples=opts.val_size,
                                       filename=opts.val_dataset,
                                       distribution=opts.data_distribution)

    if opts.resume:
        # Expects a file name such as "epoch-<n>.pt": the second
        # '-'-separated field of the base name is the finished epoch.
        epoch_resume = int(
            os.path.splitext(os.path.split(opts.resume)[-1])[0].split("-")[1])

        # Set the random states
        torch.set_rng_state(load_data['rng_state'])
        if opts.use_cuda:
            torch.cuda.set_rng_state_all(load_data['cuda_rng_state'])
        # Dumping of state was done before epoch callback, so do that now (model is loaded)
        baseline.epoch_callback(model, epoch_resume)
        print("Resuming after {}".format(epoch_resume))
        opts.epoch_start = epoch_resume + 1

    # NOTE(review): saves the whole model to ./empty.pt on every run --
    # purpose unclear from here; confirm it is still needed.
    torch.save(model, os.path.join('.', 'empty.pt'))
    if opts.eval_only:
        validate(model, val_dataset, opts)
    else:
        for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs):
            avg_time = train_epoch(model, optimizer, baseline, lr_scheduler,
                                   epoch, val_dataset, problem, tb_logger,
                                   opts, start_time, config)
            train_run.append(avg_time)

            # Collect every threshold passed so far before removing any:
            # removing from opts.save_hrs while iterating over it would skip
            # the element after each removal. All snapshots taken now would
            # share the same hr_time (and file names), so save once.
            elapsed = time() - start_time
            passed_hrs = [hr for hr in opts.save_hrs if elapsed > hr * 3600]
            if not passed_hrs:
                continue
            for hr in passed_hrs:
                opts.save_hrs.remove(hr)

            print('Saving model and state...')
            hr_time = int(round((time() - start_time) / 3600))
            with open(
                    '../models/att/hist_{}_{}hr.pickle'.format(
                        run_name, hr_time), 'wb') as handle:
                pickle.dump(train_run,
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            torch.save(
                {
                    'model': get_inner_model(model).state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'rng_state': torch.get_rng_state(),
                    'cuda_rng_state': torch.cuda.get_rng_state_all(),
                    'baseline': baseline.state_dict()
                },
                os.path.join(
                    '../models/att',
                    '{}_{}hr-model-att-only.pt'.format(
                        run_name, hr_time)))
            torch.save(
                model,
                os.path.join(
                    '../models/att',
                    '{}_{}hr-model.pt'.format(run_name, hr_time)))
Beispiel #30
0
def test__set_rng_states_cuda():
    """Regression test for https://github.com/pytorch/ignite/issues/2076.

    Passes a torch RNG state that lives on a CUDA device to
    ``_set_rng_states`` and checks that, afterwards, slot 1 of the list holds
    a CPU tensor. Requires a CUDA device (``.cuda()`` is called).
    """
    python_state = random.getstate()
    torch_state_on_gpu = torch.get_rng_state().cuda()
    numpy_state = np.random.get_state()
    rng_states = [python_state, torch_state_on_gpu, numpy_state]

    _set_rng_states(rng_states)

    # The list is inspected after the call, so _set_rng_states is expected to
    # replace element 1 in place.
    torch_slot = rng_states[1]
    assert torch_slot.device.type == "cpu"
Beispiel #31
0
        if switch_to_rl and not best:
            data = torch.load(
                (args.save_folder + '/%s_best.pth') % args.exp_name)
            torch.set_rng_state(data['torch_rng_state'])
            torch.cuda.set_rng_state(data['cuda_rng_state'])
            np.random.set_state(data['numpy_rng_state'])
            random.setstate(data['random_rng_state'])
            model.load_state_dict(data['state_dict'])
            print(
                'Resuming from epoch %d, validation loss %f, and best cider %f'
                % (data['epoch'], data['val_loss'], data['best_cider']))

        torch.save(
            {
                'torch_rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state(),
                'numpy_rng_state': np.random.get_state(),
                'random_rng_state': random.getstate(),
                'epoch': e,
                'val_loss': val_loss,
                'val_cider': val_cider,
                'state_dict': model.state_dict(),
                'optimizer': optim.state_dict(),
                'scheduler': scheduler.state_dict(),
                'patience': patience,
                'best_cider': best_cider,
                'use_rl': use_rl,
            }, (args.save_folder + '/%s_last.pth') % args.exp_name)

        if best:
    if opt.saveBest:
        savePath = os.path.join(dirPath, "model_epoch_best.pth")
    else:
        savePath = os.path.join(dirPath,
                                "model_epoch_{}.pth".format(state['epoch']))
    th.save(state, savePath)
    print("===> Checkpoint saved to {}".format(savePath))


epoch_loss = float('inf')
for epoch in range(start + 1, opt.nEpochs + 1):
    # In save-best mode train() must return this epoch's loss; it decides
    # whether a checkpoint is written.
    if not opt.saveBest:
        train(epoch)
    else:
        epoch_loss_new = train(epoch)
    test()

    on_save_epoch = not epoch % opt.saveFreq
    if on_save_epoch:
        if opt.saveBest:
            # Only checkpoint when the loss improved on the best seen so far.
            write_ckpt = epoch_loss_new < epoch_loss
            if write_ckpt:
                epoch_loss = epoch_loss_new
        else:
            write_ckpt = True
        if write_ckpt:
            state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': th.get_rng_state(),
                'params': params,
            }
            save_checkpoint(state)
    print("******************************************************")

print("\n ============ Training completed ======================\n")
 def setUp(self):
     """Seed torch with 0 unless the UNLOCK_SEED env var is set to something other than 'false'.

     When seeding, the previous CPU RNG state is kept in ``self.rng_state``.
     """
     unlock_seed = os.getenv('UNLOCK_SEED')
     seed_locked = unlock_seed is None or unlock_seed.lower() == 'false'
     if not seed_locked:
         return
     self.rng_state = torch.get_rng_state()
     torch.manual_seed(0)
Beispiel #34
0
today = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
# Append the date and time of the run to the directory, so that several runs
# do not overwrite each other.
saveDir = saveDir + today
# Create directory (no-op if it already exists)
os.makedirs(saveDir, exist_ok=True)
# Create the file where all the (hyper)parameters and results will be saved.
varsFile = os.path.join(saveDir, 'hyperparameters.txt')
with open(varsFile, 'w+') as file:
    file.write('%s\n\n' %
               datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S"))

# \\\ Save seeds for reproducibility
#   PyTorch seeds
torchState = torch.get_rng_state()
torchSeed = torch.initial_seed()
#   Numpy state: snapshot the *global* generator behind np.random.*.
#   np.random.RandomState() would build a new, entropy-seeded generator whose
#   state is unrelated to any draws made by this program.
numpyState = np.random.get_state()
#   Collect all random states
randomStates = [
    {'module': 'numpy', 'state': numpyState},
    {'module': 'torch', 'state': torchState, 'seed': torchSeed},
]
#   This list and dictionary follows the format to then be loaded, if needed,
#   by calling the loadSeed function in Utils.miscTools
saveSeed(randomStates, saveDir)
Beispiel #35
0
def _setup_logging(expdir, verbose):
    """Send root logging to ``<expdir>/train.log`` plus a stream handler.

    verbose == 1 -> INFO, verbose > 1 -> DEBUG, otherwise WARN.
    """
    if verbose == 1:
        level = logging.INFO
    elif verbose > 1:
        level = logging.DEBUG
    else:
        level = logging.WARN
    logging.basicConfig(level=level,
                        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                        datefmt='%m/%d/%Y %I:%M:%S',
                        filename=expdir + "/train.log")
    logging.getLogger().addHandler(logging.StreamHandler())
    if verbose < 1:
        # logging.warn is a deprecated alias; use logging.warning.
        logging.warning("logging is disabled.")


def _load_file_lists(waveforms, feats, option_name):
    """Return parallel ``(wav_list, feat_list)`` for a directory or list-file input.

    If *waveforms* is a directory, the ``*.wav`` names returned by ``find_files``
    are sorted and paired with ``<feats>/<name>.h5``. If it is a file, both
    *waveforms* and *feats* are read with ``read_txt``. Otherwise an error
    naming *option_name* is logged and the process exits with status 1.
    """
    if os.path.isdir(waveforms):
        filenames = sorted(find_files(waveforms, "*.wav", use_dir_name=False))
        wav_list = [waveforms + "/" + filename for filename in filenames]
        feat_list = [feats + "/" + filename.replace(".wav", ".h5") for filename in filenames]
    elif os.path.isfile(waveforms):
        wav_list = read_txt(waveforms)
        feat_list = read_txt(feats)
    else:
        logging.error("%s should be directory or list." % option_name)
        sys.exit(1)
    return wav_list, feat_list


def _crop_batch(batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss):
    """Crop one generator batch to its valid window and reshape it for the model.

    Applies the start offsets (h_ss / x_ss) and, when h_bs != -1, the lengths
    (h_bs / x_bs). Targets are taken one step after the inputs
    (x_class[1:] vs x[:-1]). batch_h and batch_x are then transposed on their
    first two axes and given a leading axis of size 1.
    Returns ``(batch_x_class, batch_x, batch_h)``.
    """
    batch_h = batch_h[h_ss:]
    batch_x_class = batch_x_class[x_ss:]
    batch_x = batch_x[x_ss:]
    if h_bs != -1:
        batch_h = batch_h[:h_bs]
        batch_x_class = batch_x_class[1:x_bs]
        batch_x = batch_x[:x_bs-1]
    else:
        batch_x = batch_x[:-1]
        batch_x_class = batch_x_class[1:]
    batch_h = batch_h.transpose(0, 1).unsqueeze(0)
    batch_x = batch_x.transpose(0, 1).unsqueeze(0)
    return batch_x_class, batch_x, batch_h


def _batch_loss(criterion, batch_output, batch_x_class, h_ss, receptive_field):
    """Return the criterion loss for one batch.

    When h_ss > 0 the first ``receptive_field`` steps are excluded (presumably
    context carried over from a previous segment; confirm against the generator).
    """
    if h_ss > 0:
        return criterion(batch_output[receptive_field:], batch_x_class[receptive_field:])
    return criterion(batch_output, batch_x_class)


def main():
    """Fine-tune a pretrained DSWNV model, or only score it with ``--init_acc``.

    Workflow:
      * Parse the command line, set up logging under ``--expdir``, seed numpy
        and torch, then load the config and the ``--pretrained`` weights.
      * Build the train and eval generators and run epochs until
        ``config.epoch_count``.
      * When the training generator signals the end of an epoch (``c_idx < 0``):
        save a checkpoint, score the evaluation set, then restore the RNG
        states captured before evaluation.

    With ``--init_acc`` no parameters are updated, and the process exits after
    the first evaluation pass.

    Exits with status 1 if no GPU is available, or if a waveform input is
    neither a directory nor a list file.
    """
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of aux feat files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation aux feat files")
    parser.add_argument("--config", required=True,
                        type=str, help="configure file")
    parser.add_argument("--pretrained", required=True,
                        type=str, help="pretrained model path")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    parser.add_argument("--init_acc", default=False,
                        type=strtobool, help="flag to compute accuracy of initial pretrained model")
    parser.add_argument("--string_path", required=True,
                        type=str, help="string path passed to the data generators (stored as config.string_path_ft)")
    # other setting
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device)

    # make experimental directory
    os.makedirs(args.expdir, exist_ok=True)

    # set log level
    _setup_logging(args.expdir, args.verbose)

    # fix seed
    # NOTE: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. subprocesses); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # cuDNN autotuning is faster but not bit-for-bit reproducible. For
    # reproducibility (slower), set cudnn.deterministic = True and
    # cudnn.benchmark = False instead.
    torch.backends.cudnn.benchmark = True

    # load config (torch.load unpickles: only load trusted files)
    config = torch.load(args.config)
    config.expdir = args.expdir
    config.pretrained = args.pretrained
    config.string_path_ft = args.string_path

    # save args as conf
    if not args.init_acc:
        torch.save(config, args.expdir + "/model.conf")

    # define network
    model = DSWNV(
        n_quantize=config.n_quantize,
        n_aux=config.n_aux,
        hid_chn=config.hid_chn,
        skip_chn=config.skip_chn,
        dilation_depth=config.dilation_depth,
        dilation_repeat=config.dilation_repeat,
        kernel_size=config.kernel_size,
        aux_kernel_size=config.aux_kernel_size,
        aux_dilation_size=config.aux_dilation_size,
        do_prob=config.do_prob,
        upsampling_factor=config.upsampling_factor)
    logging.info(model)
    criterion = nn.CrossEntropyLoss()

    checkpoint = torch.load(args.pretrained)
    model.load_state_dict(checkpoint["model"])
    epoch_idx = checkpoint["iterations"]
    logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
    # Fine-tuning restarts the epoch counter; the pretrained count is only logged.
    epoch_idx = 0

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    optimizer = None  # stays None in --init_acc (evaluation-only) mode
    if not args.init_acc:
        for param in model.parameters():
            param.requires_grad = True
        # scale_in stays frozen during fine-tuning.
        for param in model.scale_in.parameters():
            param.requires_grad = False
        n_trainable = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad) / 1000000
        logging.info('Trainable Parameters: %.3f million' % n_trainable)
        # Only these sub-modules' parameters are handed to the optimizer.
        module_list = list(model.conv_aux.parameters()) + list(model.upsampling.parameters())
        module_list += list(model.causal.parameters()) + list(model.in_x.parameters())
        module_list += list(model.dil_h.parameters()) + list(model.out_skip.parameters())
        module_list += list(model.out_1.parameters()) + list(model.out_2.parameters())
        optimizer = torch.optim.Adam(module_list, lr=config.lr)
        model.train()
    else:
        for param in model.parameters():
            param.requires_grad = False
        model.eval()

    # resume
    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        # There is no optimizer to restore in --init_acc mode.
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)

    wav_transform = transforms.Compose([lambda x: encode_mu_law(x, config.n_quantize)])

    # define generator training
    wav_list, feat_list = _load_file_lists(args.waveforms, args.feats, "--waveforms")
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))

    generator = train_generator(
        wav_list, feat_list,
        receptive_field=model.receptive_field,
        batch_size=config.batch_size,
        wav_transform=wav_transform,
        string_path=config.string_path_ft,
        training=True,
        upsampling_factor=config.upsampling_factor)

    # define generator evaluation
    wav_list_eval, feat_list_eval = _load_file_lists(
        args.waveforms_eval, args.feats_eval, "--waveforms_eval")
    logging.info("number of evaluation data = %d." % len(wav_list_eval))
    assert len(wav_list_eval) == len(feat_list_eval)
    generator_eval = train_generator(
        wav_list_eval, feat_list_eval,
        receptive_field=model.receptive_field,
        batch_size=config.batch_size,
        wav_transform=wav_transform,
        string_path=config.string_path_ft,
        training=False,
        upsampling_factor=config.upsampling_factor)

    if args.init_acc:
        epoch_idx = -1

    # train
    loss = []
    total = 0
    iter_idx = 0
    iter_count = 0
    min_eval_loss = 99999999.99
    min_eval_loss_std = 99999999.99
    # Defined up front so the log line below cannot hit an unbound name when the
    # first comparison fails (e.g. a NaN evaluation loss).
    min_idx = -1
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    while epoch_idx < config.epoch_count:
        start = time.time()
        batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        if c_idx < 0:  # summarize epoch
            # save current epoch model
            if not args.init_acc:
                numpy_random_state = np.random.get_state()
                torch_random_state = torch.get_rng_state()
                save_checkpoint(args.expdir, model, optimizer, numpy_random_state,
                                torch_random_state, epoch_idx + 1)
            # report current epoch (max(..., 1) guards an empty pass)
            logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, np.mean(np.array(loss, dtype=np.float64)),
                np.std(np.array(loss, dtype=np.float64)), total / 60.0, total / max(iter_count, 1)))
            logging.info("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"
                         "{0.seconds:02}".format(relativedelta(seconds=int((config.epoch_count - (epoch_idx + 1))*total))))
            # compute loss in evaluation data
            loss = []
            total = 0
            iter_count = 0
            if not args.init_acc:
                model.eval()
                for param in model.parameters():
                    param.requires_grad = False
            logging.info("Evaluation data")
            while True:
                start = time.time()
                batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss \
                    = next(generator_eval)
                if c_idx < 0:
                    break

                tf = batch_h.shape[0]
                ts = batch_x.shape[0]
                batch_x_class, batch_x, batch_h = _crop_batch(
                    batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss)

                batch_output = model(batch_x, batch_h)[0]
                batch_loss = _batch_loss(criterion, batch_output, batch_x_class, h_ss,
                                         model.receptive_field)

                loss.append(batch_loss.item())
                logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % (
                    os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile),
                    c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time()-start))
                iter_count += 1
                total += time.time() - start
            eval_loss = np.mean(np.array(loss, dtype=np.float64))
            eval_loss_std = np.std(np.array(loss, dtype=np.float64))
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)" % (
                epoch_idx + 1, eval_loss, eval_loss_std, total / 60.0, total / max(iter_count, 1)))
            if not args.init_acc:
                # Best epoch is chosen by mean + std of the evaluation loss.
                if (eval_loss+eval_loss_std) <= (min_eval_loss+min_eval_loss_std):
                    min_eval_loss = eval_loss
                    min_eval_loss_std = eval_loss_std
                    min_idx = epoch_idx
                logging.info("min_eval_loss=%.6f (+- %.6f), min_idx=%d" % (
                    min_eval_loss, min_eval_loss_std, min_idx+1))
                loss = []
                total = 0
                iter_count = 0
                epoch_idx += 1
                # Restore the RNG states captured before evaluation, so the
                # evaluation pass does not change training's random stream.
                np.random.set_state(numpy_random_state)
                torch.set_rng_state(torch_random_state)
                model.train()
                for param in model.parameters():
                    param.requires_grad = True
                for param in model.scale_in.parameters():
                    param.requires_grad = False
            else:
                # --init_acc: a single scoring pass is all that was requested.
                sys.exit()
            # start next epoch
            if epoch_idx < config.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx+1))
                logging.info("Training data")
                batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss \
                    = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < config.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1))

            tf = batch_h.shape[0]
            ts = batch_x.shape[0]
            batch_x_class, batch_x, batch_h = _crop_batch(
                batch_x_class, batch_x, batch_h, h_bs, x_bs, h_ss, x_ss)

            if not args.init_acc:
                # do=True presumably enables dropout in DSWNV during training; confirm.
                batch_output = model(batch_x, batch_h, do=True)[0]
            else:
                batch_output = model(batch_x, batch_h)[0]

            batch_loss = _batch_loss(criterion, batch_output, batch_x_class, h_ss,
                                     model.receptive_field)

            if not args.init_acc:
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            loss.append(batch_loss.item())
            logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % (
                os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, utt_idx+1,
                tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time() - start))
            iter_idx += 1
            iter_count += 1
            total += time.time() - start

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Beispiel #36
0
        optimiser.zero_grad()

        if args.model in ['ANP', 'NP']:
            y_target_mu, y_target_sigma, loss, log_pred, kl_target_context  = model.forward(
                x_context, y_context, x_target, y_target
            )
        else:
            y_target_mu, y_target_sigma, loss, _, _ = model.forward(x_context, y_context, x_target, y_target)
            loss = loss.sum()
        loss.backward()
        optimiser.step()

        if epoch % args.test_epoch == 0:
            # plot some samples
            # refix the seed
            random_state = torch.get_rng_state()
            for i in range(10):
                torch.manual_seed(i)
                # Sample a single GP from the distribution
                data_test, params = datagen_test.generate_curves()
                y_context = data_test.query[0][1].contiguous()[0].unsqueeze(0)
                x_context = data_test.query[0][0].contiguous()[0].unsqueeze(0)
                x_target = data_test.query[1].contiguous()[0].unsqueeze(0)
                y_target = data_test.target_y.contiguous()[0].unsqueeze(0)

                # If our model has a latent part, we want to plot many 
                # samples from the function
                if args.model in ['ANP', 'NP']:
                    y_target_mu = []
                    y_target_sigma = []
                    for j in range(10):