@contextlib.contextmanager  # required so the `yield` below works as a with-block (assumes `import contextlib`)
def freeze_rng_state():
    # Stash the CPU (and, if available, CUDA) RNG state, run the with-block
    # body, then restore both states on exit.
    rng_state = torch.get_rng_state()
    if torch.cuda.is_available():
        cuda_rng_state = torch.cuda.get_rng_state()
    yield
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)
    torch.set_rng_state(rng_state)
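A minimal usage sketch (assuming `freeze_rng_state` is exposed as the context manager above): draws inside the block do not perturb the stream observed after it.

import torch

before = torch.get_rng_state()
with freeze_rng_state():
    _ = torch.randn(3)  # consumes RNG only inside the frozen scope
after = torch.get_rng_state()
assert torch.equal(before, after)  # surrounding stream is untouched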
def forward(self, seed):
    # Store rng
    rng_cpu = torch.get_rng_state()
    rng_gpu = torch.cuda.get_rng_state()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    mask = []
    for param in self.params:
        mask.append(Variable(param.data.clone().bernoulli_(self.p)))
    # Recover rng
    torch.set_rng_state(rng_cpu)
    torch.cuda.set_rng_state(rng_gpu)
    # Compute output
    out = []
    for i, param in enumerate(self.params):
        out.append(mask[i] * param)
    return out
def noise(self, seed=None):
    if seed is None:
        eps = []
        for param in self.mean:
            eps.append(Variable(param.data.clone().normal_(0, 1)))
        return eps
    else:
        rng_cpu = torch.get_rng_state()
        rng_gpu = torch.cuda.get_rng_state()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # generate noise
        eps = []
        for param in self.mean:
            eps.append(Variable(param.data.clone().normal_(0, 1)))
        # Recover rng
        torch.set_rng_state(rng_cpu)
        torch.cuda.set_rng_state(rng_gpu)
        return eps
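A quick check of the seeded branch (a sketch; `m` stands for a hypothetical module whose `self.mean` holds a list of parameters): the same seed reproduces the same noise, and the caller's RNG stream is left where it was.

rng_before = torch.get_rng_state()
eps1 = m.noise(seed=123)
eps2 = m.noise(seed=123)
assert all(torch.equal(a, b) for a, b in zip(eps1, eps2))  # same seed, same draws
assert torch.equal(rng_before, torch.get_rng_state())      # global stream untouched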
def checkpoint(model, best_loss, epoch, LR):
    """
    Saves checkpoint of torchvision model during training.

    Args:
        model: torchvision model to be saved
        best_loss: best val loss achieved so far in training
        epoch: current epoch of training
        LR: current learning rate in training

    Returns:
        None
    """
    print('saving')
    state = {
        'model': model,
        'best_loss': best_loss,
        'epoch': epoch,
        'rng_state': torch.get_rng_state(),
        'LR': LR
    }
    torch.save(state, 'results/checkpoint')
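The matching resume step is not shown in this snippet; a sketch of what loading it back could look like, assuming the same dict keys as above:

checkpoint = torch.load('results/checkpoint')
model = checkpoint['model']
best_loss = checkpoint['best_loss']
start_epoch = checkpoint['epoch'] + 1
LR = checkpoint['LR']
torch.set_rng_state(checkpoint['rng_state'])  # resume the exact RNG stream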
def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is class_index of the target class.
    """
    # create random image that is consistent with the index id
    rng_state = torch.get_rng_state()
    torch.manual_seed(index + self.random_offset)
    img = torch.randn(*self.image_size)
    target = torch.Tensor(1).random_(0, self.num_classes)[0]
    torch.set_rng_state(rng_state)

    # convert to PIL Image
    img = transforms.ToPILImage()(img)
    if self.transform is not None:
        img = self.transform(img)
    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
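Because `__getitem__` seeds from the index and then restores the stashed state, the same index always yields the same sample regardless of other RNG use in between; a sketch, assuming `ds` is a dataset of the class above constructed with `transform=None`:

img_a, target_a = ds[7]
_ = torch.randn(100)  # unrelated RNG consumption
img_b, target_b = ds[7]
assert list(img_a.getdata()) == list(img_b.getdata())  # identical PIL image both times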
@contextlib.contextmanager
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices"):
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Arguments:
        devices (iterable of CUDA IDs): CUDA devices for which to fork
            the RNG.  CPU RNG state is always forked.  By default, fork_rng operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed
        enabled (bool): if ``False``, the RNG is not forked.  This is a convenience
            argument for easily disabling the context manager without having
            to reindent your Python code.
    """

    import torch.cuda
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = torch.cuda.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            warnings.warn(
                ("CUDA reports that you have {num_devices} available devices, and you "
                 "have used {caller} without explicitly specifying which devices are being used. "
                 "For safety, we initialize *every* CUDA device by default, which "
                 "can be quite slow if you have a lot of GPUs.  If you know that you are only "
                 "making use of a few CUDA devices, set the environment variable CUDA_VISIBLE_DEVICES "
                 "or the '{devices_kw}' keyword argument of {caller} with the set of devices "
                 "you are actually using.  For example, if you are using CPU only, "
                 "set CUDA_VISIBLE_DEVICES= or devices=[]; if you are using "
                 "GPU 0 only, set CUDA_VISIBLE_DEVICES=0 or devices=[0].  To initialize "
                 "all devices and suppress this warning, set the '{devices_kw}' keyword argument "
                 "to `range(torch.cuda.device_count())`."
                 ).format(num_devices=num_devices, caller=_caller, devices_kw=_devices_kw))
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    gpu_rng_states = []
    for device in devices:
        with torch.cuda.device(device):
            gpu_rng_states.append(torch.cuda.get_rng_state())

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, gpu_rng_state in zip(devices, gpu_rng_states):
            with torch.cuda.device(device):
                torch.cuda.set_rng_state(gpu_rng_state)
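This context manager ships as `torch.random.fork_rng`; a minimal usage example:

import torch

torch.manual_seed(0)
with torch.random.fork_rng():
    torch.manual_seed(42)
    inner = torch.rand(2)  # draws from the seed-42 stream inside the fork
outer = torch.rand(2)      # continues the original seed-0 stream as if nothing happened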
def backward(ctx, *grads):
    global timers
    see_memory_usage("In backward", force=False)
    # removing pointers to the contiguous buffer memory
    # so that they can be garbage collected once the checkpoints
    # have been used
    if SYNCHRONIZE:
        torch.cuda.synchronize()
    if PROFILE_TIME:
        timers('backward').start()

    if CONTIGUOUS_CHECKPOINTING:
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        for buffers in contiguous_data_buffers:
            buffers = []

        # frees up all the pointers to the checkpoints except for the ones
        # stored by save for backward
        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []

    see_memory_usage("In backward checkpointing code", force=False)
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), "
                           "please use .backward() if possible")

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS

    if PARTITION_ACTIVATIONS:
        if ctx.saved_tensors and ctx.saved_tensors[0].dtype == torch.bfloat16:
            FP32_COMM = True
        else:
            FP32_COMM = False
        # with torch.cuda.stream(transport_stream):
        inputs = get_full_inputs(ctx.saved_tensors,
                                 device=cuda_device if PA_TO_CPU else None,
                                 fp32_comm=FP32_COMM)
        detached_inputs = detach_variable(inputs)
    else:
        inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)

    # Add non tensor input args
    detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                    non_tensor_objects=ctx.non_tensor_args,
                                    tensor_flags=ctx.tensor_flags)

    # Store the current states.
    bwd_cpu_rng_state = torch.get_rng_state()
    bwd_cuda_rng_state = torch.cuda.get_rng_state()
    bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # Set the states to what it used to be before the forward pass.
    torch.set_rng_state(ctx.fwd_cpu_rng_state)
    _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

    # if PARTITION_ACTIVATIONS:
    #     current_stream = torch.cuda.current_stream()
    #     current_stream.wait_stream(transport_stream)

    see_memory_usage("In backward checkpointing code before forward", force=False)
    with torch.enable_grad():
        outputs = ctx.run_function(*detached_inputs)

    see_memory_usage("In backward checkpointing code after forward", force=False)
    # Set the states back to what it was at the start of this function.
    torch.set_rng_state(bwd_cpu_rng_state)
    _set_cuda_rng_state(bwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs, )

    # Filter out non tensor outputs
    outputs, _, _ = extract_tensors(all_objects=outputs)

    # Construct arguments to autograd.backward().
    # This is usually just outputs and grads, but forward() can return tensors that
    # are not differentiable.
    output_tensors = []
    grad_tensors = []
    for out, grad in zip(outputs, grads):
        if out.requires_grad:
            output_tensors.append(out)
            grad_tensors.append(grad)

    see_memory_usage("In backward checkpointing code before backward", force=False)
    torch.autograd.backward(output_tensors, grad_tensors)
    see_memory_usage("After backward checkpointing code after backward", force=False)

    if PROFILE_TIME:
        timers('backward').stop()
        timers.log(['backward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    ret_list = [None, None]  # first None for ctx
    for inp in detached_inputs:
        if torch.is_tensor(inp):
            ret_list.append(inp.grad)
        else:
            ret_list.append(None)

    return tuple(ret_list)
def train_CcGAN(kernel_sigma, kappa, train_images, train_labels, netG, netD,
                save_images_folder, save_models_folder=None, clip_label=False):
    '''
    Note that train_images are not normalized to [-1,1]
    '''

    netG = netG.to(device)
    netD = netD.to(device)

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
    optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))

    if save_models_folder is not None and resume_niters > 0:
        save_file = save_models_folder + "/CcGAN_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(
            threshold_type, resume_niters)
        checkpoint = torch.load(save_file)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
    # end if

    #################
    unique_train_labels = np.sort(np.array(list(set(train_labels))))

    # printed images with labels between the 5-th quantile and 95-th quantile of training labels
    n_row = 10
    n_col = n_row
    z_fixed = torch.randn(n_row * n_col, dim_gan, dtype=torch.float).to(device)
    start_label = np.quantile(train_labels, 0.05)
    end_label = np.quantile(train_labels, 0.95)
    selected_labels = np.linspace(start_label, end_label, num=n_row)
    y_fixed = np.zeros(n_row * n_col)
    for i in range(n_row):
        curr_label = selected_labels[i]
        for j in range(n_col):
            y_fixed[i * n_col + j] = curr_label
    print(y_fixed)
    y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1, 1).to(device)

    start_time = timeit.default_timer()
    for niter in range(resume_niters, niters):

        ''' Train Discriminator '''
        ## randomly draw batch_size_disc y's from unique_train_labels
        batch_target_labels_in_dataset = np.random.choice(unique_train_labels,
                                                          size=batch_size_max,
                                                          replace=True)
        ## add Gaussian noise; we estimate image distribution conditional on these labels
        batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_max)
        batch_target_labels_with_epsilon = batch_target_labels_in_dataset + batch_epsilons
        if clip_label:
            batch_target_labels_with_epsilon = np.clip(batch_target_labels_with_epsilon, 0.0, 1.0)

        batch_target_labels = batch_target_labels_with_epsilon[0:batch_size_disc]

        ## find index of real images with labels in the vicinity of batch_target_labels
        ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
        batch_real_indx = np.zeros(batch_size_disc, dtype=int)  # index of images in the dataset; the labels of these images are in the vicinity
        batch_fake_labels = np.zeros(batch_size_disc)

        for j in range(batch_size_disc):
            ## index for real images
            if threshold_type == "hard":
                indx_real_in_vicinity = np.where(
                    np.abs(train_labels - batch_target_labels[j]) <= kappa)[0]
            else:
                # reverse the weight function for SVDL
                indx_real_in_vicinity = np.where(
                    (train_labels - batch_target_labels[j])**2 <=
                    -np.log(nonzero_soft_weight_threshold) / kappa)[0]

            ## if the max gap between two consecutive ordered unique labels is large,
            ## it is possible that len(indx_real_in_vicinity) < 1
            while len(indx_real_in_vicinity) < 1:
                batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
                batch_target_labels[j] = batch_target_labels_in_dataset[j] + batch_epsilons_j
                if clip_label:
                    batch_target_labels = np.clip(batch_target_labels, 0.0, 1.0)
                ## index for real images
                if threshold_type == "hard":
                    indx_real_in_vicinity = np.where(
                        np.abs(train_labels - batch_target_labels[j]) <= kappa)[0]
                else:
                    # reverse the weight function for SVDL
                    indx_real_in_vicinity = np.where(
                        (train_labels - batch_target_labels[j])**2 <=
                        -np.log(nonzero_soft_weight_threshold) / kappa)[0]
            # end while len(indx_real_in_vicinity) < 1

            assert len(indx_real_in_vicinity) >= 1

            batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]

            ## labels for fake images generation
            if threshold_type == "hard":
                lb = batch_target_labels[j] - kappa
                ub = batch_target_labels[j] + kappa
            else:
                lb = batch_target_labels[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold) / kappa)
                ub = batch_target_labels[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold) / kappa)
            lb = max(0.0, lb)
            ub = min(ub, 1.0)
            assert lb <= ub
            assert lb >= 0 and ub >= 0
            assert lb <= 1 and ub <= 1
            batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
        # end for j

        ## draw the real image batch from the training set
        batch_real_images = train_images[batch_real_indx]
        assert batch_real_images.max() > 1
        batch_real_labels = train_labels[batch_real_indx]
        batch_real_labels = torch.from_numpy(batch_real_labels).type(torch.float).to(device)

        ## normalize real images to [-1,1]
        trainset = IMGs_dataset(batch_real_images, labels=None, normalize=True)
        train_dataloader = torch.utils.data.DataLoader(trainset,
                                                       batch_size=batch_size_disc,
                                                       shuffle=False)
        train_dataloader = iter(train_dataloader)
        batch_real_images = next(train_dataloader)  # `.next()` was removed from DataLoader iterators; use the builtin
        assert len(batch_real_images) == batch_size_disc
        batch_real_images = batch_real_images.type(torch.float).to(device)
        assert batch_real_images.max().item() <= 1

        ## generate the fake image batch
        batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).to(device)
        z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).to(device)
        batch_fake_images = netG(z, batch_fake_labels)

        ## target labels on gpu
        batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).to(device)

        ## weight vector
        if threshold_type == "soft":
            real_weights = torch.exp(-kappa * (batch_real_labels - batch_target_labels)**2).to(device)
            fake_weights = torch.exp(-kappa * (batch_fake_labels - batch_target_labels)**2).to(device)
        else:
            real_weights = torch.ones(batch_size_disc, dtype=torch.float).to(device)
            fake_weights = torch.ones(batch_size_disc, dtype=torch.float).to(device)
        # end if threshold type

        # forward pass
        real_dis_out = netD(batch_real_images, batch_target_labels)
        fake_dis_out = netD(batch_fake_images.detach(), batch_target_labels)

        d_loss = -torch.mean(real_weights.view(-1) * torch.log(real_dis_out.view(-1) + 1e-20)) \
                 - torch.mean(fake_weights.view(-1) * torch.log(1 - fake_dis_out.view(-1) + 1e-20))

        optimizerD.zero_grad()
        d_loss.backward()
        optimizerD.step()

        ''' Train Generator '''
        netG.train()

        # generate fake images
        batch_target_labels = batch_target_labels_with_epsilon[0:batch_size_gene]
        batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).to(device)

        z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).to(device)
        batch_fake_images = netG(z, batch_target_labels)

        # loss
        dis_out = netD(batch_fake_images, batch_target_labels)
        g_loss = -torch.mean(torch.log(dis_out + 1e-20))

        # backward
        optimizerG.zero_grad()
        g_loss.backward()
        optimizerG.step()

        # print loss
        if (niter + 1) % 20 == 0:
            print("CcGAN: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]"
                  % (niter + 1, niters, d_loss.item(), g_loss.item(),
                     real_dis_out.mean().item(), fake_dis_out.mean().item(),
                     timeit.default_timer() - start_time))

        if (niter + 1) % 100 == 0:
            netG.eval()
            with torch.no_grad():
                gen_imgs = netG(z_fixed, y_fixed)
                gen_imgs = gen_imgs.detach().cpu()
                save_image(gen_imgs.data,
                           save_images_folder + '/{}.png'.format(niter + 1),
                           nrow=n_row, normalize=True)

        if save_models_folder is not None and ((niter + 1) % save_niters_freq == 0 or (niter + 1) == niters):
            save_file = save_models_folder + "/CcGAN_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(
                threshold_type, niter + 1)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'netG_state_dict': netG.state_dict(),
                    'netD_state_dict': netD.state_dict(),
                    'optimizerG_state_dict': optimizerG.state_dict(),
                    'optimizerD_state_dict': optimizerD.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    # end for niter

    return netG, netD
def get_rng_state():
    state = {"torch_rng_state": torch.get_rng_state()}
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state
# (fragment: tail of the validation loop and per-epoch checkpointing)
loss = criterion(outputs[:, 0], labels).cpu().detach().numpy()  # calculate loss
print(f'Validation loss = {loss.item()}')
valLoss += loss

allValLoss[epoch] = valLoss / len(val_dataloader)
np.savetxt(out_path + '/allValLoss.txt', allValLoss)
np.savetxt(out_path + '/allTrainLoss.txt', allTrainLoss)

if allValLoss[epoch] < bestValLoss:
    print(f'Best seen validation performance ({bestValLoss} -> {allValLoss[epoch]}), saving...')
    torch.save(model_2.state_dict(), out_path + bestValLossNetFileName)
    np.savetxt(out_path + '/bestEpochNum.txt', np.array([epoch]))
    bestValLoss = allValLoss[epoch]

# checkpointing at the end of every epoch
torch.save(model_2.state_dict(), out_path + currModelFilename)
np.savetxt(f'{out_path}lastCompletedEpoch.txt', np.asarray([epoch]))
np.savetxt(f'{out_path}randomState.txt', torch.get_rng_state().numpy())
model_2 = model_2.train()
print(f'Epoch = {epoch} finished')

print('Finished Training')
end = time.time()
print(f'Training took {end-start} seconds')
def freeze(self):
    if 'torch' in sys.modules:
        import torch
        return torch.get_rng_state()
    else:
        return None
def get_rng_state():
    state = {'rng_state': torch.get_rng_state()}
    if torch.cuda.is_available():
        state['cuda_rng_state'] = torch.cuda.get_rng_state()
    return state
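Both `get_rng_state` helpers above only capture state; a natural counterpart for restoring it, sketched under the same dict layout:

def set_rng_state(state):
    # Restore the CPU stream, and the CUDA stream if one was captured.
    torch.set_rng_state(state['rng_state'])
    if 'cuda_rng_state' in state and torch.cuda.is_available():
        torch.cuda.set_rng_state(state['cuda_rng_state'])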
def forward(ctx, run_function, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Begining", force=True)
        if dist.get_rank() == 0:
            logger.info(f"Activation Checkpointing Information")
            logger.info(
                f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}")
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        #inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
        #inputs.append(args[-1])

        inputs = []
        for i, item in enumerate(args[:-1]):
            partition_size = get_partition_size(item)
            partition = item.detach().contiguous().view(-1).narrow(
                0, get_partition_start(item), partition_size).clone()

            if CONTIGUOUS_CHECKPOINTING:
                buffer_device = torch.device('cpu') if PA_TO_CPU else partition.device

                if i >= len(contiguous_data_buffers):
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for i in range(num_layers)
                    ]
                    contiguous_data_buffers.append(tensor_list)
                    data_offsets.append(0)
                elif contiguous_data_buffers[i] is None:
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for i in range(num_layers)
                    ]
                    contiguous_data_buffers[i] = tensor_list
                    data_offsets[i] = 0

                contiguous_partition = contiguous_data_buffers[i][
                    data_offsets[i]].data.copy_(partition.data)
                data_offsets[i] = data_offsets[i] + 1
                inputs.append(contiguous_partition)
            else:
                partition = partition.cpu() if PA_TO_CPU else partition
                inputs.append(partition)

        inputs.append(args[-1])

    #just in case something funky is happening such as reuse of inputs
    inputs_cuda = [item.to(cuda_device) for item in args]

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    #ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    del inputs_cuda

    #with torch.cuda.stream(transport_stream):
    #if PARTITION_ACTIVATIONS:
    #    new_args = []
    #    for arg, inp in zip(args, inputs):
    #        size = torch.tensor(arg.size())
    #        arg.data = inp.data
    #        new_args.append(arg)
    #        new_args.append(size)
    #    ctx.save_for_backward(*new_args)

    if PARTITION_ACTIVATIONS:
        new_args = []
        for i, (arg, inp) in enumerate(zip(args, inputs)):
            size = torch.tensor(arg.size())
            arg.data = inp.data
            new_args.append(arg)

            if CONTIGUOUS_CHECKPOINTING:
                numel = size.numel()
                if i >= len(contiguous_size_buffers):
                    tmp = torch.tensor(())
                    contiguous_size_buffers.append(
                        tmp.new_empty([numel * num_layers],
                                      dtype=size.dtype,
                                      device=size.device))
                    size_offsets.append(0)
                elif contiguous_size_buffers[i] is None:
                    tmp = torch.tensor(())
                    contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
                                                               dtype=size.dtype,
                                                               device=size.device)
                    size_offsets[i] = 0

                contiguous_size = contiguous_size_buffers[i].narrow(
                    0, size_offsets[i], numel).data.copy_(size.data)
                contiguous_size = contiguous_size.view_as(size)
                size_offsets[i] = size_offsets[i] + numel
                new_args.append(contiguous_size)
            else:
                new_args.append(size)
            #if dist.get_rank() == 0:
            #    logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")

        ctx.save_for_backward(*new_args)
    else:
        ctx.save_for_backward(*args)
    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()
    return outputs
def train(total_iters=0, skipped_iters=0, elapsed_time=False):
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0
    start_time = time.time()
    t0 = start_time
    ntokens = args.data_size
    hidden = init_hidden(args.batch_size)
    curr_loss = 0.
    distributed = isinstance(model, DDP)
    max_iters = len(train_data)

    def log(epoch, i, lr, ms_iter, total_time, loss, scale):
        print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:.2E} | ms/batch {:.3E} | total time {:.3E}\
              loss {:.2E} | ppl {:8.2f} | loss scale {:8.2f}'.format(
            epoch, i, max_iters, lr, ms_iter, total_time, loss,
            math.exp(min(loss, 20)), scale))

    for i, batch in enumerate(train_data):
        data, targets, reset_mask = get_batch(batch)
        optim.zero_grad()
        output, hidden = model(data, reset_mask=reset_mask)
        loss = criterion(output.view(-1, ntokens).contiguous().float(),
                         targets.view(-1).contiguous())
        total_loss += loss.data.float()

        if args.fp16:
            optim.backward(loss, update_master_grads=False)
        else:
            loss.backward()

        if distributed:
            torch.distributed.all_reduce(loss.data)
            loss.data /= args.world_size
            model.allreduce_params()

        # clipping gradients helps prevent the exploding gradient problem in RNNs / LSTMs.
        if args.clip > 0:
            if not args.fp16:
                torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
            else:
                optim.clip_master_grads(clip=args.clip)

        if args.fp16:
            optim.update_master_grads()

        optim.step()

        # step learning rate and log training progress
        lr = optim.param_groups[0]['lr']
        if not args.fp16:
            LR.step()
        else:
            # if fp16 optimizer skips gradient step due to explosion do not step lr
            if not optim.overflow:
                LR.step()
            else:
                skipped_iters += 1

        # log current results
        if ((i + 1) % args.log_interval == 0) and (i != max_iters - 1):
            cur_loss = total_loss[0] / args.log_interval
            cur_time = time.time()
            elapsed = cur_time - start_time
            total_elapsed = cur_time - t0 + elapsed_time
            log(epoch, i + 1, lr, elapsed * 1000 / args.log_interval, total_elapsed,
                cur_loss, args.loss_scale if not args.fp16 else optim.loss_scale)
            total_loss = 0
            start_time = cur_time
            sys.stdout.flush()

        # save current model progress. If distributed only save from worker 0
        if args.save_iters and total_iters % args.save_iters == 0 and total_iters > 0 and args.rank < 1:
            if args.rank < 1:
                with open(os.path.join(os.path.splitext(args.save)[0],
                                       'e%s.pt' % (str(total_iters), )), 'wb') as f:
                    torch.save(model.state_dict(), f)
                if args.save_optim:
                    with open(os.path.join(os.path.splitext(args.save)[0], 'optim.pt'), 'wb') as f:
                        optim_sd = optim.state_dict()
                        optim_sd['iter'] = total_iters
                        optim_sd['skipped_iter'] = skipped_iters
                        torch.save(optim_sd, f)
                        del optim_sd

                    with open(os.path.join(os.path.splitext(args.save)[0], 'rng.pt'), 'wb') as f:
                        torch.save((torch.cuda.get_rng_state(), torch.get_rng_state()), f)
            if args.cuda:
                torch.cuda.synchronize()
        total_iters += 1

    # final logging
    elapsed_iters = max_iters % args.log_interval
    if elapsed_iters == 0:
        elapsed_iters = args.log_interval

    cur_loss = total_loss[0] / elapsed_iters
    cur_time = time.time()
    elapsed = cur_time - start_time
    total_elapsed = cur_time - t0 + elapsed_time
    log(epoch, max_iters, lr, elapsed * 1000 / elapsed_iters, total_elapsed,
        cur_loss, args.loss_scale if not args.fp16 else optim.loss_scale)

    return cur_loss, skipped_iters
def train_AE():
    # define optimizer
    params = list(net_encoder.parameters()) + list(net_decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=base_lr, betas=(0.5, 0.999), weight_decay=1e-4)

    # criterion
    criterion = nn.MSELoss()

    if resume_epoch > 0:
        print("Loading ckpt to resume training AE >>>")
        ckpt_fullpath = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(
            resume_epoch, lambda_sparsity)
        checkpoint = torch.load(ckpt_fullpath)
        net_encoder.load_state_dict(checkpoint['net_encoder_state_dict'])
        net_decoder.load_state_dict(checkpoint['net_decoder_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        torch.set_rng_state(checkpoint['rng_state'])
        gen_iterations = checkpoint['gen_iterations']
    else:
        gen_iterations = 0

    start_time = timeit.default_timer()
    for epoch in range(resume_epoch, epochs):
        adjust_learning_rate(epoch, epochs, optimizer, base_lr, lr_decay_epochs, lr_decay_factor)
        train_loss = 0

        for batch_idx, batch_real_images in enumerate(trainloader):
            net_encoder.train()
            net_decoder.train()

            batch_size_curr = batch_real_images.shape[0]
            batch_real_images = batch_real_images.type(torch.float).cuda()

            batch_features = net_encoder(batch_real_images)
            batch_recons_images = net_decoder(batch_features)

            '''
            based on https://debuggercafe.com/sparse-autoencoders-using-l1-regularization-with-pytorch/
            '''
            loss = criterion(batch_recons_images, batch_real_images) + lambda_sparsity * batch_features.mean()

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.cpu().item()
            gen_iterations += 1

            if gen_iterations % 100 == 0:
                n_row = min(10, int(np.sqrt(batch_size_curr)))
                with torch.no_grad():
                    batch_recons_images = net_decoder(net_encoder(batch_real_images[0:n_row**2]))
                    batch_recons_images = batch_recons_images.detach().cpu()
                save_image(batch_recons_images.data,
                           save_AE_images_in_train_folder + '/{}.png'.format(gen_iterations),
                           nrow=n_row, normalize=True)

            if gen_iterations % 20 == 0:
                print("AE+lambda{}: [step {}] [epoch {}/{}] [train loss {}] [Time {}]".format(
                    lambda_sparsity, gen_iterations, epoch + 1, epochs,
                    train_loss / (batch_idx + 1), timeit.default_timer() - start_time))
        # end for batch_idx

        if (epoch + 1) % 50 == 0:
            save_file = save_models_folder + "/AE_checkpoint_intrain/AE_checkpoint_epoch_{}_lambda_{}.pth".format(
                epoch + 1, lambda_sparsity)
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            torch.save(
                {
                    'gen_iterations': gen_iterations,
                    'net_encoder_state_dict': net_encoder.state_dict(),
                    'net_decoder_state_dict': net_decoder.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'rng_state': torch.get_rng_state()
                }, save_file)
    # end for epoch

    return net_encoder, net_decoder
def train_model(model, datasets, criterion, optimizer):
    global args, writer
    since = time.time()
    curr_loss = 0
    lr = args.learning_rate
    flag = False
    start_itr = 0
    num_running_batch = 0
    running_batch = {
        'previmg': torch.Tensor(batchSize, 3, input_size, input_size),
        'currimg': torch.Tensor(batchSize, 3, input_size, input_size),
        'previmg_x2': torch.Tensor(batchSize, 3, input_size * 2, input_size * 2),
        'currimg_x2': torch.Tensor(batchSize, 3, input_size * 2, input_size * 2),
        'currbb': torch.Tensor(batchSize, 4)
    }
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.lr_decay_step,
                                          gamma=args.gamma)

    # resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_itr = checkpoint['itr']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            num_running_batch = checkpoint['num_running_batch']
            running_batch = checkpoint['running_batch']
            lr = checkpoint['lr']
            np.random.set_state(checkpoint['np_rand_state'])
            torch.set_rng_state(checkpoint['torch_rand_state'])
            print("=> loaded checkpoint '{}' (iteration {})".format(args.resume, checkpoint['itr']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if not os.path.isdir(args.save_directory):
        os.makedirs(args.save_directory)

    itr = start_itr
    st = time.time()
    while itr < args.num_batches:
        model.train()
        if (args.resume and os.path.isfile(args.resume) and itr == start_itr and (not flag)):
            checkpoint = torch.load(args.resume)
            i = checkpoint['dataset_indx']
            flag = True
        else:
            i = 0

        # train on datasets
        # usually ALOV and ImageNet
        while i < len(datasets):
            dataset = datasets[i]
            i = i + 1
            (running_batch, train_batch, done,
             num_running_batch) = get_training_batch(num_running_batch, running_batch, dataset)
            # print(i, num_running_batch, done)
            if done:
                scheduler.step()

                # load sample
                x1 = train_batch['previmg'].to(device)
                x2 = train_batch['currimg'].to(device)
                x1_x2 = train_batch['previmg_x2'].to(device)
                x2_x2 = train_batch['currimg_x2'].to(device)
                y = train_batch['currbb'].requires_grad_(False).to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                output = model(x1, x2, x1_x2, x2_x2)
                loss = criterion(output, y)

                # backward + optimize
                loss.backward()
                optimizer.step()

                # statistics
                curr_loss = loss.item()
                end = time.time()
                itr = itr + 1
                print('[training] step = %d/%d, loss = %f, time = %f' %
                      (itr, args.num_batches, curr_loss, end - st))
                sys.stdout.flush()
                del (train_batch)
                st = time.time()

                if enable_tensorboard:
                    writer.add_scalar('train/batch_loss', curr_loss, itr)

                if itr > 0 and itr % kSaveModel == 0:
                    path = os.path.join(
                        args.save_directory,
                        'model_itr_' + str(itr) + '_loss_' + str(round(curr_loss, 3)) + '.pth.tar')
                    save_checkpoint(
                        {
                            'itr': itr,
                            'np_rand_state': np.random.get_state(),
                            'torch_rand_state': torch.get_rng_state(),
                            'l1_loss': curr_loss,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'num_running_batch': num_running_batch,
                            'running_batch': running_batch,
                            'lr': lr,
                            'dataset_indx': i
                        }, path)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    if enable_tensorboard:
        writer.export_scalars_to_json("./all_scalars.json")
        writer.close()
    return model
def _n_features(self) -> int:
    """
    Calculates the number of extracted features going into the
    classifier part.
    :return: Number of features.
    """

    def conv_size(h_in, w_in):
        if "kernel_size" in self.conv_params.keys():
            kernel_size = (self.conv_params['kernel_size'],
                           self.conv_params['kernel_size']) if isinstance(
                self.conv_params['kernel_size'], int) else self.conv_params['kernel_size']
        else:
            kernel_size = (self.kernel_size, self.kernel_size) if isinstance(
                self.kernel_size, int) else self.kernel_size
        if "padding" in self.conv_params.keys():
            padding = (self.conv_params['padding'],
                       self.conv_params['padding']) if isinstance(
                self.conv_params['padding'], int) else self.conv_params['padding']
        else:
            padding = (0, 0)
        if "dilation" in self.conv_params.keys():
            dilation = (self.conv_params['dilation'],
                        self.conv_params['dilation']) if isinstance(
                self.conv_params['dilation'], int) else self.conv_params['dilation']
        else:
            dilation = (1, 1)
        if "stride" in self.conv_params.keys():
            stride = (self.conv_params['stride'],
                      self.conv_params['stride']) if isinstance(
                self.conv_params['stride'], int) else self.conv_params['stride']
        else:
            stride = (1, 1)
        # standard conv output-size formula
        h_out = floor(((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0]) + 1)
        w_out = floor(((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1]) + 1)
        return h_out, w_out

    def pool_size(h_in, w_in):
        # note: assumes 'kernel_size' is always present in pooling_params
        if "kernel_size" in self.pooling_params.keys():
            kernel_size = (self.pooling_params['kernel_size'],
                           self.pooling_params['kernel_size']) if isinstance(
                self.pooling_params['kernel_size'], int) else self.pooling_params['kernel_size']
        if "padding" in self.pooling_params.keys():
            padding = (self.pooling_params['padding'],
                       self.pooling_params['padding']) if isinstance(
                self.pooling_params['padding'], int) else self.pooling_params['padding']
        else:
            padding = (0, 0)
        if "dilation" in self.pooling_params.keys():
            dilation = (self.pooling_params['dilation'],
                        self.pooling_params['dilation']) if isinstance(
                self.pooling_params['dilation'], int) else self.pooling_params['dilation']
        else:
            dilation = (1, 1)
        if "stride" in self.pooling_params.keys():
            stride = (self.pooling_params['stride'],
                      self.pooling_params['stride']) if isinstance(
                self.pooling_params['stride'], int) else self.pooling_params['stride']
        else:
            stride = kernel_size
        h_out = floor(((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0]) + 1)
        w_out = floor(((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1]) + 1)
        return h_out, w_out

    # Make sure to not mess up the random state.
    rng_state = torch.get_rng_state()
    try:
        # ====== YOUR CODE: ======
        _, in_h, in_w, = tuple(self.in_size)
        from math import floor
        for channel in range(len(self.channels)):
            in_h, in_w = conv_size(in_h, in_w)
            if (channel + 1) % self.pool_every == 0:
                in_h, in_w = pool_size(in_h, in_w)
        return in_h * in_w * self.channels[-1]
        # ========================
    finally:
        torch.set_rng_state(rng_state)
def _randomize_labels(labels, seed):
    # Shuffle labels with a fixed seed, then restore the global RNG stream.
    rng_state = torch.get_rng_state()
    torch.manual_seed(seed)
    labels = list(torch.IntTensor(labels)[torch.randperm(len(labels))])
    torch.set_rng_state(rng_state)
    return labels
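A quick property check (a sketch; `labels` is any list of ints): the shuffle is deterministic in the seed and is a pure permutation.

labels = [0, 1, 2, 3, 4]
a = _randomize_labels(labels, seed=7)
b = _randomize_labels(labels, seed=7)
assert all(int(x) == int(y) for x, y in zip(a, b))  # deterministic given the seed
assert sorted(int(x) for x in a) == labels          # a permutation, nothing lost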
def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved object is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of undeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        if len(storage) == 0:

            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError(
                    "You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward.  Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)
                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this.")

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
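This helper backs the public `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` path in the PyTorch versions that expose it; a minimal usage sketch, where the stashed RNG state makes the recomputed dropout match the original forward:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)

def block(x):
    # dropout consumes RNG; the stash/restore above makes the
    # recomputation during backward see the same stream
    return torch.nn.functional.dropout(layer(x), p=0.5, training=True)

x = torch.randn(4, 16, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()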
def setUp(self):
    if os.getenv("UNLOCK_SEED") is None or os.getenv("UNLOCK_SEED").lower() == "false":
        self.rng_state = torch.get_rng_state()
        torch.manual_seed(0)
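A stash in `setUp` usually pairs with a restore in `tearDown`; a sketch under the same `UNLOCK_SEED` convention:

def tearDown(self):
    # Restore the stream stashed in setUp so tests do not leak RNG state.
    if hasattr(self, "rng_state"):
        torch.set_rng_state(self.rng_state)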
def evaluate(args, model, tokenizer, prefix="", test=False, mc=True, generation=False):
    eval_task_names = (args.task_name, )
    eval_outputs_dirs = (args.output_dir, )

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer,
                                               evaluate=not test, test=test)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_avg_ppl = 0.0
        eval_ppl = 0.0
        nb_eval_steps = 1e-8
        nb_eval_time_steps = 1e-8
        preds = None
        preds_normalized = None
        preds_mc_head = None
        out_label_ids = None
        if generation:
            output_file = open(os.path.join(eval_output_dir, "dev_generated_ans.tsv"), 'w')
        random_state = torch.get_rng_state()
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0][:, 1:, :].contiguous().view(-1, batch[0].size(2)),
                    'output_ids': batch[1][:, 1:, :].contiguous().view(-1, batch[1].size(2)),
                    'attention_mask': batch[2][:, 1:, :].contiguous().view(-1, batch[2].size(2)),
                    'token_type_ids': batch[3] if args.model_type in ['xlnet'] else None,  # XLM don't use segment_ids
                }
                mc_inputs = {
                    'input_ids': batch[0].view(-1, batch[0].size(2)),
                    'output_ids': batch[5].view(-1, batch[5].size(2)),
                    'attention_mask': batch[2].view(-1, batch[2].size(2)),
                    'token_type_ids': batch[3] if args.model_type in ['xlnet'] else None,  # XLM don't use segment_ids
                }
                #output_ids = mc_inputs['input_ids'].clone().detach()
                #output_ids = output_ids[:, 1:]
                #output_ids[output_ids == tokenizer.pad_token_id] = -1
                #output_ids[output_ids == tokenizer.cls_token_id] = -1
                #mc_inputs['output_ids'] = output_ids
                outputs = model(**inputs)
                if mc:
                    logits_mc_outputs = model(**mc_inputs)
                    mc_outputs = logits_mc_outputs
                    mc_score = (-1 * mc_outputs.sum(1)).view(-1, batch[0].size(1)).data.cpu().numpy()
                    mc_prob = F.softmax((-1 * mc_outputs.sum(1)).view(-1, batch[0].size(1)), -1).data.cpu().numpy()
                    mc_score_normalized = mc_outputs.sum(1) / ((mc_outputs != 0).float().sum(1))
                    mc_score_normalized = (mc_score_normalized * -1).view(-1, batch[0].size(1)).data.cpu().numpy()

                    if preds is None:
                        preds = mc_score
                        preds_normalized = mc_score_normalized
                        out_label_ids = batch[4].detach().cpu().numpy()
                    else:
                        preds = np.append(preds, mc_score, axis=0)
                        preds_normalized = np.append(preds_normalized, mc_score_normalized, axis=0)
                        out_label_ids = np.append(out_label_ids, batch[4].detach().cpu().numpy(), axis=0)

                #output_ids = mc_inputs['input_ids'].clone().detach()
                #output_ids = output_ids[:, 1:]
                #output_ids[output_ids == tokenizer.pad_token_id] = -1
                #output_ids[output_ids == tokenizer.cls_token_id] = -1
                #mc_inputs['output_ids'] = output_ids
                outputs = model(**inputs)
                batch_total_ppl = torch.exp(outputs.sum(1) / ((outputs != 0).float().sum(1))).sum()
                batch_total_loss = outputs.sum()
                eval_avg_ppl += batch_total_ppl.item()
                eval_ppl += batch_total_loss.item()

                if generation:
                    distractor_size = 3
                    sep_id = tokenizer.eos_token_id
                    for i in range(0, inputs['input_ids'].size(0), distractor_size):
                        tmp_output_ids = inputs['output_ids'][i, :]
                        question_length = (tmp_output_ids != -1).nonzero().min() + 1
                        distractors = set({})
                        for _ in range(distractor_size):
                            question, distractor = model.generate(
                                inputs['input_ids'][i:i + 1, :question_length],
                                30, sample=True, tmp=1.0, top_p=1.0,
                                label=inputs['input_ids'][i:i + 1, question_length:])
                            while distractor in distractors:
                                question, distractor = model.generate(
                                    inputs['input_ids'][i:i + 1, :question_length],
                                    30, sample=True, tmp=1.0, top_p=1.0,
                                    label=inputs['input_ids'][i:i + 1, question_length:])
                            distractors.add(distractor)
                        res = [question] + list(distractors)
                        output_file.write("\t".join(res) + "\n")

            nb_eval_steps += outputs.size(0)
            nb_eval_time_steps += ((outputs != 0).float().sum()).item()
        torch.set_rng_state(random_state)
        if generation:
            output_file.close()
        eval_avg_ppl = eval_avg_ppl / nb_eval_steps
        eval_ppl = exp(eval_ppl / nb_eval_time_steps)
        if mc:
            preds = F.softmax(torch.tensor(preds), -1).detach().cpu().numpy()
            np.save(os.path.join(eval_output_dir, "preds.npy"), preds)
            preds = np.argmax(preds, axis=1)
            preds_normalized = np.argmax(preds_normalized, axis=1)
            acc = simple_accuracy(preds, out_label_ids)
            acc_normalized = simple_accuracy(preds_normalized, out_label_ids)
            result = {
                "eval_ppl": eval_ppl,
                "eval_avg_ppl": eval_avg_ppl,
                "eval_acc": acc,
                "eval_acc_normalized": acc_normalized
            }
        else:
            result = {"eval_ppl": eval_ppl, "eval_avg_ppl": eval_avg_ppl}
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir,
                                        "is_test_" + str(test).lower() + "_eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
            writer.write("model =%s\n" % str(args.model_name_or_path))
            writer.write("total batch size=%d\n" %
                         (args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
                          (torch.distributed.get_world_size() if args.local_rank != -1 else 1)))
            writer.write("train num epochs=%d\n" % args.num_train_epochs)
            writer.write("fp16 =%s\n" % args.fp16)
            writer.write("max seq length =%d\n" % args.input_max_seq_length)
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results
def backward(ctx, *args):
    global timers
    #see_memory_usage("In backward", force=True)
    #removing pointers to the contiguous buffer memory
    #so that they can be garbage collected once the checkpoints
    #have been used
    if SYNCHRONIZE:
        torch.cuda.synchronize()
    if PROFILE_TIME:
        timers('backward').start()

    if CONTIGUOUS_CHECKPOINTING:
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        for buffers in contiguous_data_buffers:
            buffers = []

        #frees up all the pointers to the checkpoints except for the ones
        #stored by save for backward
        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []

    #see_memory_usage("In backward checkpointing code", force=True)
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), "
                           "please use .backward() if possible")

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS

    if PARTITION_ACTIVATIONS:
        #with torch.cuda.stream(transport_stream):
        inputs = get_full_inputs(ctx.saved_tensors,
                                 device=cuda_device if PA_TO_CPU else None)
        detached_inputs = detach_variable(inputs)
    else:
        inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)

    # Store the current states.
    bwd_cpu_rng_state = torch.get_rng_state()
    bwd_cuda_rng_state = torch.cuda.get_rng_state()
    bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # Set the states to what it used to be before the forward pass.
    torch.set_rng_state(ctx.fwd_cpu_rng_state)
    _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

    # if PARTITION_ACTIVATIONS:
    #     current_stream = torch.cuda.current_stream()
    #     current_stream.wait_stream(transport_stream)

    with torch.enable_grad():
        outputs = ctx.run_function(*detached_inputs)

    # Set the states back to what it was at the start of this function.
    torch.set_rng_state(bwd_cpu_rng_state)
    _set_cuda_rng_state(bwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs, )
    torch.autograd.backward(outputs, args)

    if PROFILE_TIME:
        timers('backward').stop()
        timers.log(['backward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()
    return (None, ) + tuple(inp.grad for inp in detached_inputs)
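Stripped of the partitioning and profiling machinery, both `backward` implementations above follow the same stash/swap/restore pattern around the recomputed forward; a distilled, CPU-only sketch (names are illustrative, not DeepSpeed's API):

import torch

def recompute_with_forward_rng(run_fn, inputs, fwd_cpu_state):
    # 1. stash the current (backward-time) stream
    bwd_cpu_state = torch.get_rng_state()
    # 2. replay the forward-time stream so dropout etc. repeats exactly
    torch.set_rng_state(fwd_cpu_state)
    with torch.enable_grad():
        outputs = run_fn(*inputs)
    # 3. restore the backward-time stream before returning
    torch.set_rng_state(bwd_cpu_state)
    return outputs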
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of aux feat files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation aux feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_aux", default=54,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--skip_chn", default=256,
                        type=int, help="number of channels of skip output")
    parser.add_argument("--seg", default=1,
                        type=int, help="segment size")
    parser.add_argument("--dilation_depth", default=3,
                        type=int, help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=2,
                        type=int, help="repeat of dilation depth")
    parser.add_argument("--hid_chn", default=192,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--kernel_size", default=7,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_kernel_size", default=3,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--aux_dilation_size", default=2,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor", default=110,
                        type=int, help="upsampling factor of aux features"
                                       "(if set 0, do not apply)")
    parser.add_argument("--n_fft_facts", default=17,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--string_path", default="/feat_org_lf0",
                        type=str, help="directory to save the model")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--batch_size", default=8800,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=4000,
                        type=int, help="number of training epochs")
    parser.add_argument("--do_prob", default=0,
                        type=float, help="dropout probability")
    parser.add_argument("--lpc", default=0,
                        type=int, help="number of linear predictive coefficients for location estimate")
    parser.add_argument("--aux_conv2d_flag", default=False,
                        type=strtobool, help="flag to use 2d conv of aux")
    parser.add_argument("--wav_conv_flag", default=False,
                        type=strtobool, help="flag to use 1d conv of wav")
    # other setting
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--pretrained", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(level=logging.WARN,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = True  # faster
    #torch.backends.cudnn.deterministic = True  # reproducibility_slower
    #torch.backends.cudnn.benchmark = False  # reproducibility_slower

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = CSWNV(
        n_aux=args.n_aux,
        skip_chn=args.skip_chn,
        hid_chn=args.hid_chn,
        dilation_depth=args.dilation_depth,
        dilation_repeat=args.dilation_repeat,
        kernel_size=args.kernel_size,
        aux_kernel_size=args.aux_kernel_size,
        aux_dilation_size=args.aux_dilation_size,
        do_prob=args.do_prob,
        seg=args.seg,
        lpc=args.lpc,
        aux_conv2d_flag=args.aux_conv2d_flag,
        wav_conv_flag=args.wav_conv_flag,
        upsampling_factor=args.upsampling_factor)
    logging.info(model)
    criterion_lsd = LSDloss()
    criterion_laplace = LaplaceLoss()

    # define transforms
    string_path_name = args.string_path.split('feat_')[1]
    logging.info(string_path_name)
    scaler = StandardScaler()
    if check_hdf5(args.stats, "/mean_"+string_path_name):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+string_path_name)
    elif check_hdf5(args.stats, "/mean_"+args.string_path):
        scaler.mean_ = read_hdf5(args.stats, "/mean_"+args.string_path)
        scaler.scale_ = read_hdf5(args.stats, "/scale_"+args.string_path)
    else:
        scaler.mean_ = read_hdf5(args.stats, "/mean_feat_"+string_path_name)
        scaler.scale_ = read_hdf5(args.stats, "/scale_feat_"+string_path_name)
    mean_src = torch.FloatTensor(scaler.mean_)
    std_src = torch.FloatTensor(scaler.scale_)

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion_lsd.cuda()
        criterion_laplace.cuda()
        mean_src = mean_src.cuda()
        std_src = std_src.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model.train()
    model.apply(initialize)
    model.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/std_src.data), 2))
    model.scale_in.bias = torch.nn.Parameter(-(mean_src.data/std_src.data))

    for param in model.parameters():
        param.requires_grad = True
    for param in model.scale_in.parameters():
        param.requires_grad = False

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters: %.3f million' % parameters)

    module_list = list(model.conv_aux.parameters())
    module_list += list(model.upsampling.parameters())
    if model.aux_conv2d_flag and model.seg > 1:
        module_list += list(model.aux_conv2d.parameters())
    if model.wav_conv_flag:
        module_list += list(model.wav_conv.parameters())
    module_list += list(model.causal.parameters()) + list(model.in_x.parameters())
    module_list += list(model.dil_h.parameters()) + list(model.out_skip.parameters())
    module_list += list(model.out_1.parameters()) + list(model.out_2.parameters())
    optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None:
        checkpoint = torch.load(args.pretrained)
        model.load_state_dict(checkpoint["model"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    if args.pretrained is None:
        generator = train_generator(
            wav_list, feat_list, model.receptive_field,
            string_path=args.string_path,
            seg=model.seg, batch_size=args.batch_size,
            training=True, upsampling_factor=args.upsampling_factor)
    else:
        generator = train_generator(
            wav_list, feat_list, model.receptive_field,
            string_path=args.string_path,
            seg=model.seg, batch_size=args.batch_size,
            training=True, upsampling_factor=args.upsampling_factor)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames_eval = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames_eval]
        feat_list_eval = [args.feats_eval + "/" + filename.replace(".wav", ".h5")
                          for filename in filenames_eval]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    logging.info("number of evaluation data = %d." % len(wav_list_eval))
    assert len(wav_list_eval) == len(feat_list_eval)
    if args.pretrained is None:
        generator_eval = train_generator(
            wav_list_eval, feat_list_eval, model.receptive_field,
            string_path=args.string_path,
            seg=model.seg, batch_size=args.batch_size,
            training=False, upsampling_factor=args.upsampling_factor)
    else:
        generator_eval = train_generator(
            wav_list_eval, feat_list_eval, model.receptive_field,
            string_path=args.string_path,
            seg=model.seg, batch_size=args.batch_size,
            training=False, upsampling_factor=args.upsampling_factor)

    # train
    loss_laplace = []
    loss_err = []
    loss_lsd = []
    fft_facts = []
    init_fft = 64
    hann_win = [None]*args.n_fft_facts
    if args.n_fft_facts == 5:
        fft_facts = [128, 256, 512, 1024, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    elif args.n_fft_facts == 9:
        fft_facts = [128, 192, 256, 384, 512, 768, 1024, 1536, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    elif args.n_fft_facts == 17:
        fft_facts = [128, 160, 192, 224, 256, 320, 384, 448, 512, 640, 768, 896, 1024, 1280, 1536, 1792, 2048]
        for i in range(args.n_fft_facts):
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    else:
        for i in range(args.n_fft_facts):
            if i % 2 == 0:
                init_fft *= 2
                fft_facts.append(init_fft)
            else:
                fft_facts.append(init_fft+int(init_fft/2))
            hann_win[i] = torch.hann_window(fft_facts[i]).cuda()
    logging.info(fft_facts)
    batch_stft_loss = [None]*args.n_fft_facts
    stft_out = [None]*args.n_fft_facts
    stft_trg = [None]*args.n_fft_facts
    total = 0
    iter_idx = 0
    iter_count = 0
    min_eval_loss_lsd = 99999999.99
    min_eval_loss_laplace = 99999999.99
    min_eval_loss_err = 99999999.99
    min_eval_loss_lsd_std = 99999999.99
    min_eval_loss_laplace_std = 99999999.99
    min_eval_loss_err_std = 99999999.99
    min_idx = -1
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    #args.epoch_count = 5300
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator)
        if c_idx < 0:  # summarize epoch
            # save current epoch model
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            save_checkpoint(args.expdir, model, optimizer, numpy_random_state,
                            torch_random_state, epoch_idx+1)
            # report current epoch
            logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "
                         "(+- %.6f) (%.3f min., %.3f sec / batch)" % (
                             epoch_idx + 1, np.mean(loss_laplace), np.std(loss_laplace),
                             np.mean(loss_lsd), np.std(loss_lsd),
                             np.mean(loss_err), np.std(loss_err),
                             total / 60.0, total / iter_count))
            logging.info("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"
                         "{0.seconds:02}".format(relativedelta(
                             seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            loss_lsd = []
            loss_err = []
            loss_laplace = []
            total = 0
            iter_count = 0
            model.eval()
            for param in model.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = \
                        next(generator_eval)
                    if c_idx < 0:
                        break

                    tf = batch_h.shape[0]
                    ts = batch_x_float.shape[0]

                    batch_h = batch_h[h_ss:]
                    batch_x_ = batch_x_float[x_ss:]
                    if model.lpc > 0:
                        if x_ss+model.lpc_offset >= 0:
                            batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:]
                        else:
                            batch_x_lpc = batch_x_float[x_ss:]
                    if h_bs != -1:
                        batch_h = batch_h[:h_bs]
                        if model.lpc > 0:
                            if x_ss+model.lpc_offset >= 0:
                                batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0)
                            else:
                                batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0),
                                                     'constant', 0).unsqueeze(0)
                        batch_x = batch_x_[:x_bs-model.seg]
                        batch_x_float = batch_x_[model.seg:x_bs]
                    else:
                        if model.lpc > 0:
                            if x_ss+model.lpc_offset > 0:
                                batch_x_prob = batch_x_lpc.unsqueeze(0)
                            else:
                                batch_x_prob = F.pad(batch_x_lpc, (-(x_ss+model.lpc_offset)),
                                                     'constant', 0).unsqueeze(0)
                        batch_x = batch_x_[:-model.seg]
                        batch_x_float = batch_x_[model.seg:]
                    batch_h = batch_h.transpose(0, 1).unsqueeze(0)
                    batch_x = batch_x.unsqueeze(0).unsqueeze(1)
                    if h_ss > 0:
                        feat_len = batch_x_float[model.receptive_field:].shape[0]
                    else:
                        feat_len = batch_x_float.shape[0]

                    if model.lpc > 0:
                        mus, bs, log_bs, ass = model(batch_h, batch_x)
                        # jump off s samples as in synthesis
                        mus = mus[:, ::model.seg, :]
                        bs = bs[:, ::model.seg, :]
                        log_bs = log_bs[:, ::model.seg, :]
                        ass = ass[:, ::model.seg, :].flip(-1)
                        init_mus = mus
                        for j in range(model.seg):
                            tmp_smpls = batch_x_prob[:, j:-(model.seg-j)].unfold(1, model.lpc, model.seg)
                            lpc = torch.sum(ass*tmp_smpls, -1, keepdim=True)
                            if j > 0:
                                mus = torch.cat((mus, lpc+init_mus[:, :, j:j+1]), 2)
                            else:
                                mus = lpc+init_mus[:, :, j:j+1]
                        mus = mus.reshape(mus.shape[0], -1)
                        bs = bs.reshape(bs.shape[0], -1)
                        log_bs = log_bs.reshape(log_bs.shape[0], -1)
                    else:
                        mus, bs, log_bs = model(batch_h, batch_x)
                    if h_ss > 0:
                        mus = mus[0, model.receptive_field:]
                        bs = bs[0, model.receptive_field:]
                        log_bs = log_bs[0, model.receptive_field:]
                        batch_x_float = batch_x_float[model.receptive_field:]
                    else:
                        mus = mus[0]
                        bs = bs[0]
                        log_bs = log_bs[0]

                    m_sum = 0
                    batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs)
                    eps = torch.empty(mus.shape).cuda().uniform_(-0.4999, 0.5)
                    batch_output = mus-bs*eps.sign()*torch.log1p(-2*eps.abs())
                    batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float))
                    logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float),
                                                    torch.max(batch_x_float), torch.var(batch_x_float)))
                    logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output),
                                                    torch.max(batch_output), torch.var(batch_output)))

                    m = 0
                    for i in range(args.n_fft_facts):
                        if feat_len > int(fft_facts[i]/2):
                            stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i])
                            stft_trg[i] = torch.stft(batch_x_float, fft_facts[i], window=hann_win[i])
                            tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i])
                            if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss):
                                if m > 0:
                                    batch_loss_lsd = torch.cat((batch_loss_lsd,
                                                                tmp_batch_stft_loss.unsqueeze(0)))
                                else:
                                    batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0)
                                m += 1
                    loss_err.append(batch_loss_err.item())
                    loss_laplace.append(batch_loss_laplace.item())
                    if m > 0:
                        batch_loss_lsd = torch.mean(batch_loss_lsd)
                        loss_lsd.append(batch_loss_lsd.item())
                        logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f "
                                     "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),
                                                                  os.path.basename(wavfile)),
                                                     c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs,
                                                     batch_loss_laplace.item(), batch_loss_lsd.item(),
                                                     batch_loss_err.item(), time.time() - start))
                    else:
                        logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f "
                                     "(%.3f sec)" % (os.path.join(os.path.basename(os.path.dirname(wavfile)),
                                                                  os.path.basename(wavfile)),
                                                     c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs,
                                                     batch_loss_laplace.item(),
                                                     batch_loss_err.item(), time.time() - start))
batch_loss_err.item(), time.time() - start)) iter_count += 1 total += time.time() - start eval_loss_lsd = np.mean(loss_lsd) eval_loss_lsd_std = np.std(loss_lsd) eval_loss_err = np.mean(loss_err) eval_loss_err_std = np.std(loss_err) eval_loss_laplace = np.mean(loss_laplace) eval_loss_laplace_std = np.std(loss_laplace) logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f "\ "(+- %.6f) (%.3f min., %.3f sec / batch)" % (epoch_idx + 1, eval_loss_laplace, \ eval_loss_laplace_std, eval_loss_lsd, eval_loss_lsd_std, eval_loss_err, eval_loss_err_std, \ total / 60.0, total / iter_count)) if (eval_loss_laplace+eval_loss_laplace_std+eval_loss_lsd+eval_loss_lsd_std+eval_loss_err\ +eval_loss_err_std) <= (min_eval_loss_laplace+min_eval_loss_laplace_std+min_eval_loss_lsd\ +min_eval_loss_lsd_std+min_eval_loss_err+min_eval_loss_err_std): min_eval_loss_lsd = eval_loss_lsd min_eval_loss_lsd_std = eval_loss_lsd_std min_eval_loss_err = eval_loss_err min_eval_loss_err_std = eval_loss_err_std min_eval_loss_laplace = eval_loss_laplace min_eval_loss_laplace_std = eval_loss_laplace_std min_idx = epoch_idx logging.info("min_eval_loss = %.6f (+- %.6f) %.6f dB (+- %.6f dB) %.6f (+- %.6f) min_idx=%d" % ( min_eval_loss_laplace, min_eval_loss_laplace_std, min_eval_loss_lsd, min_eval_loss_lsd_std, \ min_eval_loss_err, min_eval_loss_err_std, min_idx+1)) loss_lsd = [] loss_laplace = [] loss_err = [] total = 0 iter_count = 0 epoch_idx += 1 np.random.set_state(numpy_random_state) torch.set_rng_state(torch_random_state) model.train() for param in model.parameters(): param.requires_grad = True for param in model.scale_in.parameters(): param.requires_grad = False # start next epoch if epoch_idx < args.epoch_count: start = time.time() logging.info("==%d EPOCH==" % (epoch_idx+1)) logging.info("Training data") batch_x_float, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator) # feedforward and backpropagate current batch if epoch_idx < args.epoch_count: logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1)) tf = batch_h.shape[0] ts = batch_x_float.shape[0] batch_h = batch_h[h_ss:] batch_x_ = batch_x_float[x_ss:] if model.lpc > 0: if x_ss+model.lpc_offset >= 0: batch_x_lpc = batch_x_float[x_ss+model.lpc_offset:] else: batch_x_lpc = batch_x_float[x_ss:] if h_bs != -1: batch_h = batch_h[:h_bs] if model.lpc > 0: if x_ss+model.lpc_offset >= 0: batch_x_prob = batch_x_lpc[:x_bs-model.lpc_offset].unsqueeze(0) else: batch_x_prob = F.pad(batch_x_lpc[:x_bs], (-(x_ss+model.lpc_offset), 0), \ 'constant', 0).unsqueeze(0) batch_x = batch_x_[:x_bs-model.seg] batch_x_float = batch_x_[model.seg:x_bs] else: if model.lpc > 0: if x_ss+model.lpc_offset > 0: batch_x_prob = batch_x_lpc.unsqueeze(0) else: batch_x_prob = F.pad(batch_x_lpc, (-(x_ss+model.lpc_offset)), \ 'constant', 0).unsqueeze(0) batch_x = batch_x_[:-model.seg] batch_x_float = batch_x_[model.seg:] batch_h = batch_h.transpose(0,1).unsqueeze(0) batch_x = batch_x.unsqueeze(0).unsqueeze(1) if h_ss > 0: if model.seg > 1: feat_len = batch_x_float[model.receptive_field:-(model.seg-1)].shape[0] else: feat_len = batch_x_float[model.receptive_field:].shape[0] else: if model.seg > 1: feat_len = batch_x_float[:-(model.seg-1)].shape[0] else: feat_len = batch_x_float.shape[0] if model.lpc > 0: mus, bs_noclip, bs, log_bs, ass = model(batch_h, batch_x, do=True, clip=True) ass = ass.flip(-1) init_mus = mus for j in range(model.seg): tmp_smpls = batch_x_prob[:,j:-(model.seg-j)].unfold(1, model.lpc, 1) lpc = 
torch.sum(ass*tmp_smpls,-1,keepdim=True) if j > 0: mus = torch.cat((mus, lpc+init_mus[:,:,j:j+1]),2) else: mus = lpc+init_mus[:,:,j:j+1] if model.seg == 1: mus = mus.reshape(mus.shape[0], -1) bs_noclip = bs_noclip.reshape(mus.shape[0], -1) bs = bs.reshape(mus.shape[0], -1) log_bs = log_bs.reshape(mus.shape[0], -1) else: mus, bs_noclip, bs, log_bs = model(batch_h, batch_x, do=True, clip=True) if h_ss > 0: mus = mus[0,model.receptive_field:] bs_noclip = bs_noclip[0,model.receptive_field:] bs = bs[0,model.receptive_field:] log_bs = log_bs[0,model.receptive_field:] batch_x_float = batch_x_float[model.receptive_field:] else: mus = mus[0] bs_noclip = bs_noclip[0] bs = bs[0] log_bs = log_bs[0] m_sum = 0 if model.seg > 1: n_sum = 0 for i in range(model.seg): if i > 0: i_n = i+1 mus_i = mus[:,i:i_n].squeeze(-1) bs_noclip_i = bs_noclip[:,i:i_n].squeeze(-1) if i_n < model.seg: batch_x_float_i = batch_x_float[i:-(model.seg-(i_n))] else: batch_x_float_i = batch_x_float[i:] tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,i:i_n].squeeze(-1), \ batch_x_float_i, log_b=log_bs[:,i:i_n].squeeze(-1), log=False) batch_loss_laplace = torch.cat((batch_loss_laplace, \ tmp_batch_loss_laplace.unsqueeze(0))) else: mus_i = mus[:,:1].squeeze(-1) bs_noclip_i = bs_noclip[:,:1].squeeze(-1) batch_x_float_i = batch_x_float[:-(model.seg-1)] tmp_batch_loss_laplace = criterion_laplace(mus_i, bs[:,:1].squeeze(-1), \ batch_x_float_i, log_b=log_bs[:,:1].squeeze(-1)) batch_loss_laplace = tmp_batch_loss_laplace.unsqueeze(0) eps = torch.empty(mus_i.shape).cuda().uniform_(-0.4999,0.5) batch_output = mus_i-bs_noclip_i*eps.sign()*torch.log1p(-2*eps.abs()) tmp_batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float_i)) if i > 0: batch_loss_err = torch.cat((batch_loss_err, tmp_batch_loss_err.unsqueeze(0))) else: batch_loss_err = tmp_batch_loss_err.unsqueeze(0) if i == 0: logging.info("%lf %E %lf %E" % (torch.min(batch_x_float_i), \ torch.mean(batch_x_float_i), torch.max(batch_x_float_i), torch.var(batch_x_float_i))) logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \ torch.max(batch_output), torch.var(batch_output))) n = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i]) stft_trg[i] = torch.stft(batch_x_float_i, fft_facts[i], window=hann_win[i]) tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if n > 0: tmp_batch_loss_stft_l1 = torch.cat((tmp_batch_loss_stft_l1, \ tmp_batch_stft_loss.unsqueeze(0))) else: tmp_batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0) n += 1 if n > 0: if n_sum > 0: batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \ torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0))) else: batch_loss_stft_l1 = torch.mean(tmp_batch_loss_stft_l1).unsqueeze(0) n_sum += n m = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i]) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if m > 0: tmp_batch_loss_lsd = torch.cat((tmp_batch_loss_lsd, \ tmp_batch_stft_loss.unsqueeze(0))) else: tmp_batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0) m += 1 if m > 0: if m_sum > 0: batch_loss_lsd = torch.cat((batch_loss_lsd, \ torch.mean(tmp_batch_loss_lsd).unsqueeze(0))) else: batch_loss_lsd = torch.mean(tmp_batch_loss_lsd).unsqueeze(0) m_sum += m batch_loss_laplace = 
torch.mean(batch_loss_laplace) batch_loss = batch_loss_laplace if n_sum > 0: batch_loss += torch.mean(batch_loss_stft_l1) if m_sum > 0: batch_loss_lsd = torch.mean(batch_loss_lsd) batch_loss_err = torch.mean(batch_loss_err) else: batch_loss_laplace = criterion_laplace(mus, bs, batch_x_float, log_b=log_bs) batch_loss = batch_loss_laplace eps = torch.empty(mus.shape).cuda().uniform_(-0.4999,0.5) batch_output = mus-bs_noclip*eps.sign()*torch.log1p(-2*eps.abs()) batch_loss_err = torch.mean(torch.abs(batch_output-batch_x_float)) logging.info("%lf %E %lf %E" % (torch.min(batch_x_float), torch.mean(batch_x_float), \ torch.max(batch_x_float), torch.var(batch_x_float))) logging.info("%lf %E %lf %E" % (torch.min(batch_output), torch.mean(batch_output), \ torch.max(batch_output), torch.var(batch_output))) n = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): stft_out[i] = torch.stft(batch_output, fft_facts[i], window=hann_win[i]) stft_trg[i] = torch.stft(batch_x_float, fft_facts[i], window=hann_win[i]) tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i], LSD=False, L2=False) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if n > 0: batch_loss_stft_l1 = torch.cat((batch_loss_stft_l1, \ tmp_batch_stft_loss.unsqueeze(0))) else: batch_loss_stft_l1 = tmp_batch_stft_loss.unsqueeze(0) n += 1 if n > 0: batch_loss += torch.mean(batch_loss_stft_l1) m = 0 for i in range(args.n_fft_facts): if feat_len > int(fft_facts[i]/2): tmp_batch_stft_loss = criterion_lsd(stft_out[i], stft_trg[i]) if not torch.isinf(tmp_batch_stft_loss) and not torch.isnan(tmp_batch_stft_loss): if m > 0: batch_loss_lsd = torch.cat((batch_loss_lsd, tmp_batch_stft_loss.unsqueeze(0))) else: batch_loss_lsd = tmp_batch_stft_loss.unsqueeze(0) m += 1 if m > 0: batch_loss_lsd = torch.mean(batch_loss_lsd) optimizer.zero_grad() batch_loss.backward() optimizer.step() loss_err.append(batch_loss_err.item()) loss_laplace.append(batch_loss_laplace.item()) if (model.seg > 1 and m_sum > 0) or (model.seg == 1 and m > 0): loss_lsd.append(batch_loss_lsd.item()) logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f %.3f dB %.6f (%.3f sec)" % ( os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \ utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \ batch_loss_lsd.item(), batch_loss_err.item(), time.time() - start)) else: logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f n/a %.6f (%.3f sec)" % ( os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, \ utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss_laplace.item(), \ batch_loss_err.item(), time.time() - start)) iter_idx += 1 iter_count += 1 total += time.time() - start # save final model model.cpu() torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl") logging.info("final checkpoint created.")
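# A minimal sketch (an illustration, not part of the original script) of
# reloading the final checkpoint saved above; it assumes a CSWNV `model`
# constructed with the same configuration as during training.
checkpoint = torch.load(args.expdir + "/checkpoint-final.pkl", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()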
def __enter__(self):
    self.old_rng_state = torch.get_rng_state()
    torch.manual_seed(self.rng_seed)
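# For completeness, a matching __exit__ would restore the state captured in
# __enter__; a sketch, assuming the attribute names used above:
def __exit__(self, exc_type, exc_value, traceback):
    torch.set_rng_state(self.old_rng_state)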
def record_rng(self, *args):
    self.cpu_state = torch.get_rng_state()
    if torch.cuda._initialized:
        self.cuda_in_fwd = True
        self.gpu_devices, self.gpu_states = get_device_states(*args)
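# The replay counterpart would push the recorded states back before
# re-running the computation; a sketch, assuming set_device_states mirrors
# get_device_states (as in torch.utils.checkpoint):
def replay_rng(self):
    torch.set_rng_state(self.cpu_state)
    if self.cuda_in_fwd:
        set_device_states(self.gpu_devices, self.gpu_states)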
def random_splits_mask_class(
    dataset,
    num_train_per_class=20,
    num_val_per_class=30,
    num_val=None,
    num_test=None,
    seed=None,
):
    r"""If the data has masks for train/val/test, return the splits with a
    specific number of samples from every class for training, as suggested in
    Pitfalls of graph neural network evaluation [1]_ for semi-supervised
    learning.

    References
    ----------
    .. [1] Shchur, O., Mumme, M., Bojchevski, A., & Günnemann, S. (2018).
        Pitfalls of graph neural network evaluation.
        arXiv preprint arXiv:1811.05868.

    Parameters
    ----------
    num_train_per_class : int
        the number of samples from every class used for training.
    num_val_per_class : int
        the number of samples from every class used for validation.
    num_val : int
        alternatively, the total number of nodes used for validation.
    num_test : int
        alternatively, the total number of nodes used for testing.
        If num_test is None, the rest of the data is selected as the test set.
    seed : int
        random seed for splitting the dataset.
    """
    data = dataset[0]
    r_s = torch.get_rng_state()
    if torch.cuda.is_available():
        r_s_cuda = torch.cuda.get_rng_state()
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
    num_classes = data.y.max().cpu().item() + 1
    try:
        data.train_mask.fill_(False)
        data.val_mask.fill_(False)
        data.test_mask.fill_(False)
    except AttributeError:
        train_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=data.edge_index.device)
        val_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=data.edge_index.device)
        test_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=data.edge_index.device)
        setattr(data, "train_mask", train_mask)
        setattr(data, "val_mask", val_mask)
        setattr(data, "test_mask", test_mask)
    for c_i in range(num_classes):
        idx = (data.y == c_i).nonzero().view(-1)
        assert num_train_per_class + num_val_per_class < idx.size(0), (
            "the total number of samples from every class used for training "
            "and validation is larger than the total samples in class " + str(c_i))
        idx_idx_rand = torch.randperm(idx.size(0))
        idx_train = idx[idx_idx_rand[:num_train_per_class]]
        idx_val = idx[idx_idx_rand[num_train_per_class:num_train_per_class + num_val_per_class]]
        data.train_mask[idx_train] = True
        data.val_mask[idx_val] = True
    if num_val is not None:
        remaining = (~data.train_mask).nonzero().view(-1)
        remaining = remaining[torch.randperm(remaining.size(0))]
        data.val_mask[remaining[:num_val]] = True
        if num_test is not None:
            data.test_mask[remaining[num_val:num_val + num_test]] = True
        else:
            data.test_mask[remaining[num_val:]] = True
    else:
        remaining = (~(data.train_mask + data.val_mask)).nonzero().view(-1)
        data.test_mask[remaining] = True
    torch.set_rng_state(r_s)
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(r_s_cuda)
    dataset.data, dataset.slices = dataset.collate([d for d in dataset])
    return dataset
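# Hypothetical usage sketch: re-split a dataset with a fixed seed so the
# masks are reproducible while the global RNG state is left untouched.
dataset = random_splits_mask_class(dataset, num_train_per_class=20,
                                   num_val_per_class=30, seed=42)
data = dataset[0]
print(int(data.train_mask.sum()), int(data.val_mask.sum()), int(data.test_mask.sum()))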
##### load data #####
graph_list = pickle.load(open(processed + dataset_name + 'networkx_graphs.pkl', 'rb'))
data_len = len(graph_list)

##### training parameters #####
epoch_samples = [k for k in [0, 24, 49, 75, 99, 124, 149, 174, 199] if k < max_epoch]
bs = {'DHFR/': 11, 'MUTAG/': 10, 'COX2/': 9,
      'IMDB-BINARY/': 18, 'NCI1/': 20, 'IMDB-MULTI/': 27}
batch_size = bs[dataset_name]
test_size = data_len // 10
train_batches = np.ceil((data_len - test_size) / batch_size).astype(int)
print('number of batches = ', train_batches, ' batch size = ', batch_size)
torch.manual_seed(0)
rng_state = torch.get_rng_state()  # seed init to ensure the same initial conditions for each training run

#### eigenvalue path signature ####
if xtra_feat:
    pslevel = 4
    sig_prep = iisignature.prepare(2, pslevel)
    siglength = iisignature.logsiglength(2, pslevel)
    xtra_feat_length = siglength
    if xxtra:
        siglength += 4
else:
    xtra_feat_length = 0
print(xtra_feat_length)
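# Sketch of how the saved rng_state would be used: reset to it before each
# repeated training run so every run starts from identical initial weights
# (`build_model` is a hypothetical constructor, not from the original code).
torch.set_rng_state(rng_state)
model = build_model()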
def train(self, net, optimizer, resume=False, scheduler=None):
    logging.info('=' * 80)
    logging.info('Start training')
    self.log_datetime()
    logging.info('=' * 80)

    train_set = self.get_train_set()
    test_sets = self.get_test_sets()

    net = net.to(self.train_device)

    epoch = 0
    min_err = {ts.name: 1e9 for ts in test_sets}

    state_path = self.exp_out_root / 'state.dict'
    if resume and state_path.exists():
        logging.info('=' * 80)
        logging.info(f'Loading state from {state_path}')
        logging.info('=' * 80)
        state = torch.load(str(state_path))
        epoch = state['epoch'] + 1
        if 'min_err' in state:
            min_err = state['min_err']
        curr_state = net.state_dict()
        curr_state.update(state['state_dict'])
        net.load_state_dict(curr_state)
        try:
            optimizer.load_state_dict(state['optimizer'])
        except Exception:
            logging.info('Warning: cannot load optimizer from state_dict')
        if 'cpu_rng_state' in state:
            torch.set_rng_state(state['cpu_rng_state'])
        if 'gpu_rng_state' in state:
            torch.cuda.set_rng_state(state['gpu_rng_state'])

    for epoch in range(epoch, self.epochs):
        self.callback_train_new_epoch(epoch, net, optimizer)

        # train epoch
        self.train_epoch(epoch, net, optimizer, train_set)

        # test epoch
        errs = self.test(epoch, net, test_sets)

        if (epoch + 1) % self.save_frequency == 0:
            net = net.to(self.train_device)

            # store state
            state_dict = {
                'epoch': epoch,
                'min_err': min_err,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'cpu_rng_state': torch.get_rng_state(),
                'gpu_rng_state': torch.cuda.get_rng_state(),
            }
            state_path = self.exp_out_root / 'state.dict'
            logging.info(f'save state to {state_path}')
            torch.save(state_dict, str(state_path))

            for test_set_name in errs:
                err = sum(errs[test_set_name])
                if err < min_err[test_set_name]:
                    min_err[test_set_name] = err
                    state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
                    logging.info(f'save state to {state_path}')
                    torch.save(state_dict, str(state_path))

            # store network
            net_path = self.get_net_path(epoch)
            logging.info(f'save network to {net_path}')
            torch.save(net.state_dict(), str(net_path))

        if scheduler is not None:
            scheduler.step()

    logging.info('=' * 80)
    logging.info('Finished training')
    self.log_datetime()
    logging.info('=' * 80)
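# Hypothetical usage sketch (`trainer` stands for an instance of the class
# defining train above): resuming picks up the epoch counter, optimizer
# state, and both CPU and GPU RNG states from state.dict.
trainer.train(net, optimizer, resume=True, scheduler=scheduler)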
def run(opts): # start time start_time = time() train_run = [] opts.save_hrs.sort() run_name = opts.run_name # Pretty print the run args pp.pprint(vars(opts)) # Set the random seed torch.manual_seed(opts.seed) # Optionally configure tensorboard tb_logger = None if not opts.no_tensorboard: tb_logger = TbLogger( os.path.join(opts.log_dir, "{}_{}".format(opts.problem, opts.graph_size), opts.run_name)) os.makedirs(opts.save_dir) # Save arguments so exact configuration can always be found with open(os.path.join(opts.save_dir, "args.json"), 'w') as f: json.dump(vars(opts), f, indent=True) # Set the device opts.device = torch.device("cuda:0" if opts.use_cuda else "cpu") # Figure out what's the problem problem = load_problem(opts.problem) # Load data from load_path load_data = {} assert opts.load_path is None or opts.resume is None, "Only one of load path and resume can be given" load_path = opts.load_path if opts.load_path is not None else opts.resume if load_path is not None: print(' [*] Loading data from {}'.format(load_path)) load_data = torch_load_cpu(load_path) # hyperparameter search # default (user specified) config config_defaults = { 'batch_size': opts.batch_size, 'lr_model': opts.lr_model, 'lr_critic': opts.lr_critic, 'lr_decay': opts.lr_decay, } # determine the parameter space """sweep_config = { 'parameters': { 'batch_size': { 'values': [256, 128, 64, 32] }, 'lr_model': { 'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5] }, 'lr_critic': { 'values': [1e-2, 1e-3, 1e-4, 3e-4, 3e-5, 1e-5] }, 'lr_decay': { 'lr_decay': [0.9, 0.95, 1.0, 1.05, 1.1, 1.15] }, } }""" # initialize the sweep # sweep_id = wandb.sweep(sweep_config, project="Pytorch-sweeps") # Initialize a new wandb run wandb.init(config=config_defaults) # Config is a variable that holds and saves hyperparameters and inputs config = wandb.config # ??? 
any code for setting up hyperparameters interested should use config.parameter to set instead of opt.parameter # including functions in other files-> pass config to other functions # Initialize model model_class = { 'attention': AttentionModel, 'pointer': PointerNetwork }.get(opts.model, None) assert model_class is not None, "Unknown model: {}".format(model_class) model = model_class(opts.embedding_dim, opts.hidden_dim, problem, n_encode_layers=opts.n_encode_layers, mask_inner=True, mask_logits=True, normalization=opts.normalization, tanh_clipping=opts.tanh_clipping, checkpoint_encoder=opts.checkpoint_encoder, shrink_size=opts.shrink_size).to(opts.device) if opts.use_cuda and torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) # Overwrite model parameters by parameters to load model_ = get_inner_model(model) model_.load_state_dict({ **model_.state_dict(), **load_data.get('model', {}) }) # Initialize baseline if opts.baseline == 'exponential': baseline = ExponentialBaseline(opts.exp_beta) elif opts.baseline == 'critic' or opts.baseline == 'critic_lstm': assert problem.NAME == 'tsp', "Critic only supported for TSP" baseline = CriticBaseline( (CriticNetworkLSTM(2, opts.embedding_dim, opts.hidden_dim, opts.n_encode_layers, opts.tanh_clipping) if opts.baseline == 'critic_lstm' else CriticNetwork( 2, opts.embedding_dim, opts.hidden_dim, opts.n_encode_layers, opts.normalization)).to(opts.device)) elif opts.baseline == 'rollout': baseline = RolloutBaseline(model, problem, opts) else: assert opts.baseline is None, "Unknown baseline: {}".format( opts.baseline) baseline = NoBaseline() if opts.bl_warmup_epochs > 0: baseline = WarmupBaseline(baseline, opts.bl_warmup_epochs, warmup_exp_beta=opts.exp_beta) # Load baseline from data, make sure script is called with same type of baseline if 'baseline' in load_data: baseline.load_state_dict(load_data['baseline']) # Initialize optimizer optimizer = optim.Adam([{ 'params': model.parameters(), 'lr': config.lr_model }] + ([{ 'params': baseline.get_learnable_parameters(), 'lr': config.lr_critic }] if len(baseline.get_learnable_parameters()) > 0 else [])) # Load optimizer state if 'optimizer' in load_data: optimizer.load_state_dict(load_data['optimizer']) for state in optimizer.state.values(): for k, v in state.items(): # if isinstance(v, torch.Tensor): if torch.is_tensor(v): state[k] = v.to(opts.device) # Initialize learning rate scheduler, decay by lr_decay once per epoch! 
lr_scheduler = optim.lr_scheduler.LambdaLR( optimizer, lambda epoch: config.lr_decay**epoch) # Start the actual training loop val_dataset = problem.make_dataset(size=opts.graph_size, num_samples=opts.val_size, filename=opts.val_dataset, distribution=opts.data_distribution) if opts.resume: epoch_resume = int( os.path.splitext(os.path.split(opts.resume)[-1])[0].split("-")[1]) torch.set_rng_state(load_data['rng_state']) if opts.use_cuda: torch.cuda.set_rng_state_all(load_data['cuda_rng_state']) # Set the random states # Dumping of state was done before epoch callback, so do that now (model is loaded) baseline.epoch_callback(model, epoch_resume) print("Resuming after {}".format(epoch_resume)) opts.epoch_start = epoch_resume + 1 torch.save(model, os.path.join('.', 'empty.pt')) if opts.eval_only: validate(model, val_dataset, opts) else: for epoch in range(opts.epoch_start, opts.epoch_start + opts.n_epochs): avg_time = train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, problem, tb_logger, opts, start_time, config) train_run.append(avg_time) for hr in opts.save_hrs: if (time() - start_time) > hr * 3600: opts.save_hrs.remove(hr) print('Saving model and state...') hr_time = int(round((time() - start_time) / 3600)) with open( '../models/att/hist_{}_{}hr.pickle'.format( run_name, hr_time), 'wb') as handle: pickle.dump(train_run, handle, protocol=pickle.HIGHEST_PROTOCOL) torch.save( { 'model': get_inner_model(model).state_dict(), 'optimizer': optimizer.state_dict(), 'rng_state': torch.get_rng_state(), 'cuda_rng_state': torch.cuda.get_rng_state_all(), 'baseline': baseline.state_dict() }, os.path.join( '../models/att', '{}_{}hr-model-att-only.pt'.format( run_name, hr_time))) torch.save( model, os.path.join( '../models/att', '{}_{}hr-model.pt'.format(run_name, hr_time)))
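# The resume path above relies on this save/restore round trip for RNG
# state; a minimal self-contained sketch:
state = {'rng_state': torch.get_rng_state()}
if torch.cuda.is_available():
    state['cuda_rng_state'] = torch.cuda.get_rng_state_all()
torch.save(state, 'rng-checkpoint.pt')
# ... later, in a fresh process ...
loaded = torch.load('rng-checkpoint.pt')
torch.set_rng_state(loaded['rng_state'])
if 'cuda_rng_state' in loaded:
    torch.cuda.set_rng_state_all(loaded['cuda_rng_state'])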
def test__set_rng_states_cuda():
    # Checks https://github.com/pytorch/ignite/issues/2076
    rng_states = [random.getstate(), torch.get_rng_state().cuda(), np.random.get_state()]
    _set_rng_states(rng_states)
    assert rng_states[1].device.type == "cpu"
if switch_to_rl and not best:
    data = torch.load((args.save_folder + '/%s_best.pth') % args.exp_name)
    torch.set_rng_state(data['torch_rng_state'])
    torch.cuda.set_rng_state(data['cuda_rng_state'])
    np.random.set_state(data['numpy_rng_state'])
    random.setstate(data['random_rng_state'])
    model.load_state_dict(data['state_dict'])
    print('Resuming from epoch %d, validation loss %f, and best cider %f' %
          (data['epoch'], data['val_loss'], data['best_cider']))

torch.save({
    'torch_rng_state': torch.get_rng_state(),
    'cuda_rng_state': torch.cuda.get_rng_state(),
    'numpy_rng_state': np.random.get_state(),
    'random_rng_state': random.getstate(),
    'epoch': e,
    'val_loss': val_loss,
    'val_cider': val_cider,
    'state_dict': model.state_dict(),
    'optimizer': optim.state_dict(),
    'scheduler': scheduler.state_dict(),
    'patience': patience,
    'best_cider': best_cider,
    'use_rl': use_rl,
}, (args.save_folder + '/%s_last.pth') % args.exp_name)

if best:
if opt.saveBest:
    savePath = os.path.join(dirPath, "model_epoch_best.pth")
else:
    savePath = os.path.join(dirPath, "model_epoch_{}.pth".format(state['epoch']))
th.save(state, savePath)
print("===> Checkpoint saved to {}".format(savePath))

epoch_loss = float('inf')
for epoch in range(start + 1, opt.nEpochs + 1):
    if opt.saveBest:
        epoch_loss_new = train(epoch)
    else:
        train(epoch)
    test()
    if not epoch % opt.saveFreq:
        if opt.saveBest and epoch_loss_new < epoch_loss:
            epoch_loss = epoch_loss_new
            state = {'epoch': epoch, 'model_state_dict': model.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict(),
                     'rng_state': th.get_rng_state(), 'params': params}
            save_checkpoint(state)
        elif not opt.saveBest:
            state = {'epoch': epoch, 'model_state_dict': model.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict(),
                     'rng_state': th.get_rng_state(), 'params': params}
            save_checkpoint(state)

print("******************************************************")
print("\n ============ Training completed ======================\n")
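# Sketch of restoring from the checkpoint dict saved above (illustrative,
# with `savePath` as produced by save_checkpoint):
state = th.load(savePath)
model.load_state_dict(state['model_state_dict'])
optimizer.load_state_dict(state['optimizer_state_dict'])
th.set_rng_state(state['rng_state'])
start = state['epoch']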
def setUp(self):
    if os.getenv('UNLOCK_SEED') is None or os.getenv('UNLOCK_SEED').lower() == 'false':
        self.rng_state = torch.get_rng_state()
        torch.manual_seed(0)
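# A matching tearDown (a sketch; the original suite may differ) restores the
# captured state so the forced seed does not leak into other tests:
def tearDown(self):
    if hasattr(self, 'rng_state'):
        torch.set_rng_state(self.rng_state)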
today = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
# Append the date and time of the run to the directory, to avoid several runs
# overwriting each other.
saveDir = saveDir + today
# Create directory
if not os.path.exists(saveDir):
    os.makedirs(saveDir)
# Create the file where all the (hyper)parameters and results will be saved.
varsFile = os.path.join(saveDir, 'hyperparameters.txt')
with open(varsFile, 'w+') as file:
    file.write('%s\n\n' % datetime.datetime.now().strftime("%Y/%m/%d %H:%M:%S"))

# \\\ Save seeds for reproducibility
# PyTorch seeds
torchState = torch.get_rng_state()
torchSeed = torch.initial_seed()
# Numpy seeds
numpyState = np.random.RandomState().get_state()
# Collect all random states
randomStates = []
randomStates.append({})
randomStates[0]['module'] = 'numpy'
randomStates[0]['state'] = numpyState
randomStates.append({})
randomStates[1]['module'] = 'torch'
randomStates[1]['state'] = torchState
randomStates[1]['seed'] = torchSeed
# This list and dictionary follow the format expected, if needed, by the
# loadSeed function in Utils.miscTools
saveSeed(randomStates, saveDir)
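# Sketch of restoring these states in a later run; the repo's loadSeed
# helper presumably wraps calls like these:
np.random.set_state(randomStates[0]['state'])
torch.set_rng_state(randomStates[1]['state'])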
def main(): parser = argparse.ArgumentParser() # path setting parser.add_argument("--waveforms", required=True, type=str, help="directory or list of wav files") parser.add_argument("--waveforms_eval", required=True, type=str, help="directory or list of evaluation wav files") parser.add_argument("--feats", required=True, type=str, help="directory or list of aux feat files") parser.add_argument("--feats_eval", required=True, type=str, help="directory or list of evaluation aux feat files") parser.add_argument("--config", required=True, type=str, help="configure file") parser.add_argument("--pretrained", required=True, type=str, help="pretrained model path") parser.add_argument("--expdir", required=True, type=str, help="directory to save the model") parser.add_argument("--init_acc", default=False, type=strtobool, help="flag to compute accuracy of initial pretrained model") parser.add_argument("--string_path", required=True, type=str, help="flag to compute accuracy of initial pretrained model") # other setting parser.add_argument("--seed", default=1, type=int, help="seed number") parser.add_argument("--resume", default=None, type=str, help="model path to restart training") parser.add_argument("--GPU_device", default=None, type=int, help="selection of GPU device") parser.add_argument("--verbose", default=1, type=int, help="log level") args = parser.parse_args() if args.GPU_device is not None: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device) # make experimental directory if not os.path.exists(args.expdir): os.makedirs(args.expdir) # set log level if args.verbose == 1: logging.basicConfig(level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) elif args.verbose > 1: logging.basicConfig(level=logging.DEBUG, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) else: logging.basicConfig(level=logging.WARN, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) logging.warn("logging is disabled.") # fix seed os.environ['PYTHONHASHSEED'] = str(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) torch.backends.cudnn.benchmark = True #faster #torch.backends.cudnn.deterministic = True #reproducibility_slower #torch.backends.cudnn.benchmark = False #reproducibility_slower # load config config = torch.load(args.config) config.expdir = args.expdir config.pretrained = args.pretrained config.string_path_ft = args.string_path # save args as conf if not args.init_acc: torch.save(config, args.expdir + "/model.conf") # # define network model = DSWNV( n_quantize=config.n_quantize, n_aux=config.n_aux, hid_chn=config.hid_chn, skip_chn=config.skip_chn, dilation_depth=config.dilation_depth, dilation_repeat=config.dilation_repeat, kernel_size=config.kernel_size, aux_kernel_size=config.aux_kernel_size, aux_dilation_size=config.aux_dilation_size, do_prob=config.do_prob, upsampling_factor=config.upsampling_factor) logging.info(model) criterion = nn.CrossEntropyLoss() checkpoint = torch.load(args.pretrained) model.load_state_dict(checkpoint["model"]) epoch_idx = checkpoint["iterations"] logging.info("pretrained from 
%d-iter checkpoint." % epoch_idx) epoch_idx = 0 # send to gpu if torch.cuda.is_available(): model.cuda() criterion.cuda() else: logging.error("gpu is not available. please check the setting.") sys.exit(1) if not args.init_acc: for param in model.parameters(): param.requires_grad = True for param in model.scale_in.parameters(): param.requires_grad = False parameters = filter(lambda p: p.requires_grad, model.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000 logging.info('Trainable Parameters: %.3f million' % parameters) module_list = list(model.conv_aux.parameters())+ list(model.upsampling.parameters()) module_list += list(model.causal.parameters()) + list(model.in_x.parameters()) module_list += list(model.dil_h.parameters()) + list(model.out_skip.parameters()) module_list += list(model.out_1.parameters()) + list(model.out_2.parameters()) optimizer = torch.optim.Adam(module_list, lr=config.lr) model.train() else: for param in model.parameters(): param.requires_grad = False model.eval() # resume if args.resume is not None: checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint["model"]) optimizer.load_state_dict(checkpoint["optimizer"]) epoch_idx = checkpoint["iterations"] logging.info("restored from %d-iter checkpoint." % epoch_idx) wav_transform = transforms.Compose([lambda x: encode_mu_law(x, config.n_quantize)]) # define generator training if os.path.isdir(args.waveforms): filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False)) wav_list = [args.waveforms + "/" + filename for filename in filenames] feat_list = [args.feats + "/" + filename.replace(".wav", ".h5") for filename in filenames] elif os.path.isfile(args.waveforms): wav_list = read_txt(args.waveforms) feat_list = read_txt(args.feats) else: logging.error("--waveforms should be directory or list.") sys.exit(1) assert len(wav_list) == len(feat_list) logging.info("number of training data = %d." % len(wav_list)) generator = train_generator( wav_list, feat_list, receptive_field=model.receptive_field, batch_size=config.batch_size, wav_transform=wav_transform, string_path=config.string_path_ft, training=True, upsampling_factor=config.upsampling_factor) # define generator evaluation if os.path.isdir(args.waveforms_eval): filenames_eval = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False)) wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames_eval] feat_list_eval = [args.feats_eval + "/" + filename.replace(".wav", ".h5") \ for filename in filenames_eval] elif os.path.isfile(args.waveforms_eval): wav_list_eval = read_txt(args.waveforms_eval) feat_list_eval = read_txt(args.feats_eval) else: logging.error("--waveforms_eval should be directory or list.") sys.exit(1) logging.info("number of evaluation data = %d." 
% len(wav_list_eval)) assert len(wav_list_eval) == len(feat_list_eval) generator_eval = train_generator( wav_list_eval, feat_list_eval, receptive_field=model.receptive_field, batch_size=config.batch_size, wav_transform=wav_transform, string_path=config.string_path_ft, training=False, upsampling_factor=config.upsampling_factor) if args.init_acc: epoch_idx = -1 # train loss = [] total = 0 iter_idx = 0 iter_count = 0 min_eval_loss = 99999999.99 min_eval_loss_std = 99999999.99 if args.resume is not None: np.random.set_state(checkpoint["numpy_random_state"]) torch.set_rng_state(checkpoint["torch_random_state"]) logging.info("==%d EPOCH==" % (epoch_idx+1)) logging.info("Training data") while epoch_idx < config.epoch_count: start = time.time() batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss = next(generator) if c_idx < 0: # summarize epoch # save current epoch model if not args.init_acc: numpy_random_state = np.random.get_state() torch_random_state = torch.get_rng_state() save_checkpoint(args.expdir, model, optimizer, numpy_random_state, \ torch_random_state, epoch_idx + 1) # report current epoch logging.info("(EPOCH:%d) average training loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)" % ( epoch_idx + 1, np.mean(np.array(loss, dtype=np.float64)), \ np.std(np.array(loss, dtype=np.float64)), total / 60.0, total / iter_count)) logging.info("estimated training required time = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\ "{0.seconds:02}".format(relativedelta(seconds=int((config.epoch_count - (epoch_idx + 1))*total)))) # compute loss in evaluation data loss = [] total = 0 iter_count = 0 if not args.init_acc: model.eval() for param in model.parameters(): param.requires_grad = False logging.info("Evaluation data") while True: start = time.time() batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss \ = next(generator_eval) if c_idx < 0: break tf = batch_h.shape[0] ts = batch_x.shape[0] batch_h = batch_h[h_ss:] batch_x_class = batch_x_class[x_ss:] batch_x = batch_x[x_ss:] if h_bs != -1: batch_h = batch_h[:h_bs] batch_x_class = batch_x_class[1:x_bs] batch_x = batch_x[:x_bs-1] else: batch_x = batch_x[:-1] batch_x_class = batch_x_class[1:] batch_h = batch_h.transpose(0,1).unsqueeze(0) batch_x = batch_x.transpose(0,1).unsqueeze(0) batch_output = model(batch_x, batch_h)[0] if h_ss > 0: batch_loss = criterion(batch_output[model.receptive_field:], \ batch_x_class[model.receptive_field:]) else: batch_loss = criterion(batch_output, batch_x_class) loss.append(batch_loss.item()) logging.info("batch eval loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % ( os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), \ c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time()-start)) iter_count += 1 total += time.time() - start eval_loss = np.mean(np.array(loss, dtype=np.float64)) eval_loss_std = np.std(np.array(loss, dtype=np.float64)) logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) (%.3f min., %.3f sec / batch)"%( epoch_idx + 1, eval_loss, eval_loss_std, total / 60.0, total / iter_count)) if not args.init_acc: if (eval_loss+eval_loss_std) <= (min_eval_loss+min_eval_loss_std): min_eval_loss = eval_loss min_eval_loss_std = eval_loss_std min_idx = epoch_idx logging.info("min_eval_loss=%.6f (+- %.6f), min_idx=%d" % (\ min_eval_loss, min_eval_loss_std, min_idx+1)) loss = [] total = 0 iter_count = 0 epoch_idx += 1 np.random.set_state(numpy_random_state) torch.set_rng_state(torch_random_state) 
model.train() for param in model.parameters(): param.requires_grad = True for param in model.scale_in.parameters(): param.requires_grad = False else: exit() # start next epoch if epoch_idx < config.epoch_count: start = time.time() logging.info("==%d EPOCH==" % (epoch_idx+1)) logging.info("Training data") batch_x_class, batch_x, batch_h, c_idx, utt_idx, wavfile, h_bs, x_bs, h_ss, x_ss \ = next(generator) # feedforward and backpropagate current batch if epoch_idx < config.epoch_count: logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1)) tf = batch_h.shape[0] ts = batch_x.shape[0] batch_h = batch_h[h_ss:] batch_x_class = batch_x_class[x_ss:] batch_x = batch_x[x_ss:] if h_bs != -1: batch_h = batch_h[:h_bs] batch_x_class = batch_x_class[1:x_bs] batch_x = batch_x[:x_bs-1] else: batch_x = batch_x[:-1] batch_x_class = batch_x_class[1:] batch_h = batch_h.transpose(0,1).unsqueeze(0) batch_x = batch_x.transpose(0,1).unsqueeze(0) if not args.init_acc: batch_output = model(batch_x, batch_h, do=True)[0] else: batch_output = model(batch_x, batch_h)[0] if h_ss > 0: batch_loss = criterion(batch_output[model.receptive_field:], \ batch_x_class[model.receptive_field:]) else: batch_loss = criterion(batch_output, batch_x_class) if not args.init_acc: optimizer.zero_grad() batch_loss.backward() optimizer.step() loss.append(batch_loss.item()) logging.info("batch loss %s [%d:%d] %d %d %d %d %d %d = %.3f (%.3f sec)" % ( os.path.basename(os.path.dirname(wavfile))+"/"+os.path.basename(wavfile), c_idx+1, utt_idx+1, tf, ts, h_ss, h_bs, x_ss, x_bs, batch_loss.item(), time.time() - start)) iter_idx += 1 iter_count += 1 total += time.time() - start # save final model model.cpu() torch.save({"model": model.state_dict()}, args.expdir + "/checkpoint-final.pkl") logging.info("final checkpoint created.")
optimiser.zero_grad()
if args.model in ['ANP', 'NP']:
    y_target_mu, y_target_sigma, loss, log_pred, kl_target_context = model.forward(
        x_context, y_context, x_target, y_target)
else:
    y_target_mu, y_target_sigma, loss, _, _ = model.forward(
        x_context, y_context, x_target, y_target)
loss = loss.sum()
loss.backward()
optimiser.step()

if epoch % args.test_epoch == 0:
    # plot some samples
    # refix the seed
    random_state = torch.get_rng_state()
    for i in range(10):
        torch.manual_seed(i)
        # Sample a single GP from the distribution
        data_test, params = datagen_test.generate_curves()
        y_context = data_test.query[0][1].contiguous()[0].unsqueeze(0)
        x_context = data_test.query[0][0].contiguous()[0].unsqueeze(0)
        x_target = data_test.query[1].contiguous()[0].unsqueeze(0)
        y_target = data_test.target_y.contiguous()[0].unsqueeze(0)
        # If our model has a latent part, we want to plot many
        # samples from the function
        if args.model in ['ANP', 'NP']:
            y_target_mu = []
            y_target_sigma = []
            for j in range(10):
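# After the plotting loop, the saved state would be restored so the
# test-time reseeding does not perturb training randomness; a sketch:
torch.set_rng_state(random_state)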