def fit(self, vector_size, eta=1e-4, epochs=100, optimiser='adagrad',
        stop=None, tau=1e-7, **optimiser_kwargs):
    if isinstance(optimiser, str):
        optimiser = get_optimiser(optimiser)

    logger(f'fitting with vector size = {vector_size:,d}')
    r, c, x = self.r, self.c, self.x

    # Filter out co-occurrences that are not frequent enough
    if self.x_min is not None:
        _r, _c, x = r[self._idx], c[self._idx], x[self._idx]
        ur = {r: i for i, r in enumerate(np.unique(_r))}
        uc = {c: i for i, c in enumerate(np.unique(_c))}
        r = np.array([ur[r] for r in _r]).astype('int32')
        c = np.array([uc[c] for c in _c]).astype('int32')

        # Free memory
        del _r, _c, ur, uc; gc.collect()

    # Compute the max if not set, then cap values at it
    x_max = x.max() if self.x_max is None else self.x_max
    if self.x_max is not None:
        rprint('setting x_max upper bound')
        _x = np.minimum(x, x_max)

        rprint('precomputing f(X)')
        fx = (_x / x_max) ** self.alpha

        # Free memory
        del _x; gc.collect()
    else:
        rprint('precomputing f(X)')
        fx = (x / x_max) ** self.alpha

    rprint('precomputing log(X)')
    lx = np.log(x)

    # Free memory
    del x; gc.collect()

    np.random.seed(self.random_state)
    shape = len(np.unique(r)), vector_size

    rprint('initialising word vectors and bias vector variables')
    W1 = np.random.normal(scale=0.5, size=shape).astype('float32')
    W2 = np.random.normal(scale=0.5, size=shape).astype('float32')
    b1 = np.random.normal(scale=0.5, size=shape[0]).astype('float32')
    b2 = np.random.normal(scale=0.5, size=shape[0]).astype('float32')

    # As the sparse matrix may have multiple entries per row, compute these
    # index masks beforehand for later ease
    rprint('computing masks for optimisation')
    rmasks = {}
    cmasks = {}
    for d, masks in zip([r, c], [rmasks, cmasks]):
        for i, val in enumerate(d):
            if val not in masks:
                masks[val] = []
            masks[val] += [i]

    # Free memory (masks is bound to cmasks, so it cannot be deleted)
    del d; gc.collect()

    # Initialise one optimiser each for W1, W2 and the biases
    # (b1 and b2 share optim[2])
    optim = [optimiser(eta=eta, **optimiser_kwargs) for _ in range(3)]

    logger('initialised variables')

    u = Update('optimising epoch', epochs)
    L = self.L = np.ones(epochs + 1) * np.inf
    N = fx.sum()
    lo = np.inf
    for i in range(epochs):
        # Early stopping condition: stop if over the last "stop" iterations
        # there is a total variation of less than "tau"
        if stop is not None and i >= stop:
            if (L[i - stop: i].max() / L[i - stop: i].min() - 1) <= tau:
                break

        delta = (W1[r] * W2[c]).sum(axis=1) + b1[r] + b2[c] - lx
        L[i] = np.mean(fx * np.square(delta))

        # Store the best parameters seen so far
        if L[i] < lo:
            best = [W1.copy(), W2.copy(), b1.copy(), b2.copy()]
            lo = L[i]

        # Chain rule of a loss of the form L = fx * delta ** 2 w.r.t. delta
        # (ignoring proportional constants)
        chain = fx * delta

        # Compute gradients to update W and b, i.e. differentiate delta
        # w.r.t. W and b respectively
        #
        # Steps:
        #   • Compute adjusted gradients using the optimiser
        #   • Aggregate gradients for each token (row of W)
        #   • Update the parameter
        #   • Free space to reduce the memory cost
        #
        # Do this for W1 (optim[0]), W2 (optim[1]), b1 (optim[2]) and
        # b2 (optim[2]). The gradients for b1 and b2 are identical up to
        # the aggregation masks r and c
        gw1 = optim[0](np.einsum('c,cv->cv', chain, W2[c]).astype('float32'))
        gW1 = np.zeros_like(W1)
        for j, mask in rmasks.items():
            gW1[j] += gw1[mask].mean(axis=0)
        W1 -= gW1
        del gw1, gW1; gc.collect()

        gw2 = optim[1](np.einsum('c,cv->cv', chain, W1[r]).astype('float32'))
        gW2 = np.zeros_like(W2)
        for j, mask in cmasks.items():
            gW2[j] += gw2[mask].mean(axis=0)
        W2 -= gW2
        del gw2, gW2; gc.collect()

        # Common gradients for b1 and b2 with different aggregations
        gb = optim[2](chain.astype('float32'))
        gb1 = np.zeros_like(b1)
        for j, mask in rmasks.items():
            gb1[j] += gb[mask].mean(axis=0)
        b1 -= gb1
        del gb1; gc.collect()

        gb2 = np.zeros_like(b2)
        for j, mask in cmasks.items():
            gb2[j] += gb[mask].mean(axis=0)
        b2 -= gb2
        del chain, gb2; gc.collect()

        # Verbose update
        u.increment()
        u.display(loss=L[i], best=lo)
    else:
        # Entered only if the for loop completes without break
        i += 1

    delta = (W1[r] * W2[c]).sum(axis=1) + b1[r] + b2[c] - lx
    L[i] = np.sum(fx * np.square(delta)) / N
    if L[i] == L.min():
        best = [W1.copy(), W2.copy(), b1.copy(), b2.copy()]

    self.W, self.Wc, self.b, self.bc = best
    self.L = L[:i + 1]
    logger(f'optimised over {i:,d} epochs (best loss = {min(L):,.3e}, '
           f'final loss = {L[i]:,.3e})')
    return self
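# A minimal, self-contained sketch of the weighted least-squares objective
# that fit() minimises: L = mean(f(X) * (w_i . w'_j + b_i + b'_j - log X_ij)^2)
# with f(X) = min(1, (X / x_max) ** alpha), as in GloVe. The toy counts and
# dimensions below are hypothetical, purely for illustration.
import numpy as np

x = np.array([1.0, 5.0, 120.0])            # toy co-occurrence counts X_ij
x_max, alpha = 100.0, 0.75                 # common GloVe defaults
fx = np.minimum(x / x_max, 1.0) ** alpha   # weighting f(X), capped at 1
lx = np.log(x)                             # regression target log(X)

rng = np.random.default_rng(0)
w1, w2 = rng.normal(scale=0.5, size=(2, 3, 8))  # 3 pairs of 8-dim vectors
b1 = b2 = np.zeros(3)

delta = (w1 * w2).sum(axis=1) + b1 + b2 - lx    # residual per pair
loss = np.mean(fx * delta ** 2)                 # same form as L[i] above
print(f'loss = {loss:.3f}')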
def finetune(encoder, mlp, dataloaders, args):
    ''' Finetune script - SimCLR

        Freeze the encoder and train the supervised classification head
        with a Cross Entropy Loss.
    '''
    mode = 'finetune'

    ''' Optimisers '''
    # Only optimise the supervised head
    optimiser = get_optimiser((mlp,), mode, args)

    ''' Schedulers '''
    # Cosine LR decay
    lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.finetune_epochs)

    ''' Loss / Criterion '''
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Initialise variables
    args.writer = SummaryWriter(args.summaries_dir)

    best_valid_loss = np.inf
    best_valid_acc = 0.0
    patience_counter = 0

    n_batches = len(dataloaders['train'])
    print(f"n_batches: {n_batches}")

    ''' Finetune loop '''
    for epoch in range(args.finetune_epochs):

        # Freeze the encoder, train the classification head
        encoder.eval()
        mlp.train()

        sample_count = 0
        run_loss = 0
        run_top1 = 0.0
        run_top5 = 0.0

        # Print setup for distributed training: only print on one node.
        if args.print_progress:
            logging.info('\nEpoch {}/{}:\n'.format(epoch + 1,
                                                   args.finetune_epochs))
            # tqdm for process (rank) 0 only when using distributed training
            train_dataloader = tqdm(dataloaders['train'])
        else:
            train_dataloader = dataloaders['train']

        ''' epoch loop '''
        for i, (inputs, target) in enumerate(train_dataloader):
            inputs = inputs.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # Forward pass
            optimiser.zero_grad()

            # Do not compute gradients for the frozen encoder
            with torch.no_grad():
                h = encoder(inputs)  # take pretrained encoder representations

            output = mlp(h)

            loss = criterion(output, target)

            loss.backward()
            optimiser.step()

            torch.cuda.synchronize()

            sample_count += inputs.size(0)
            run_loss += loss.item()

            predicted = output.argmax(1)
            acc = (predicted == target).sum().item() / target.size(0)
            run_top1 += acc

            _, output_topk = output.topk(5, 1, True, True)
            acc_top5 = (output_topk == target.view(-1, 1).expand_as(
                output_topk)).sum().item() / target.size(0)  # num correct
            run_top5 += acc_top5

        epoch_finetune_loss = run_loss / len(dataloaders['train'])  # sample_count
        epoch_finetune_acc = run_top1 / len(dataloaders['train'])
        epoch_finetune_acc_top5 = run_top5 / len(dataloaders['train'])

        ''' Update Schedulers '''
        # Decay lr with CosineAnnealingLR
        lr_decay.step()

        ''' Printing '''
        if args.print_progress:  # only validate using process 0
            logging.info(
                '\n[Finetune] loss: {:.4f},\t acc: {:.4f},\t acc_top5: {:.4f}\n'.format(
                    epoch_finetune_loss, epoch_finetune_acc,
                    epoch_finetune_acc_top5))

            args.writer.add_scalars('finetune_epoch_loss',
                                    {'train': epoch_finetune_loss}, epoch + 1)
            args.writer.add_scalars('finetune_epoch_acc',
                                    {'train': epoch_finetune_acc}, epoch + 1)
            args.writer.add_scalars('finetune_epoch_acc_top5',
                                    {'train': epoch_finetune_acc_top5},
                                    epoch + 1)
            args.writer.add_scalars('finetune_lr',
                                    {'train': optimiser.param_groups[0]['lr']},
                                    epoch + 1)

        # Log the validation losses
        valid_loss, valid_acc, valid_acc_top5 = evaluate(
            encoder, mlp, dataloaders, 'valid', epoch, args)

        # For the best performing epoch, reset patience and save the model;
        # otherwise update patience.
        if valid_acc >= best_valid_acc:
            patience_counter = 0
            best_epoch = epoch + 1
            best_valid_acc = valid_acc

            # Saving using process (rank) 0 only, as all processes are in sync
            state = {
                #'args': args,
                'encoder': encoder.state_dict(),
                'supp_mlp': mlp.state_dict(),
                'optimiser': optimiser.state_dict(),
                'epoch': epoch
            }
            torch.save(state, args.checkpoint_dir[:-3] + "_finetune.pt")
        else:
            patience_counter += 1
            if patience_counter == (args.patience - 10):
                logging.info('\nPatience counter {}/{}.'.format(
                    patience_counter, args.patience))
            elif patience_counter == args.patience:
                logging.info(
                    '\nEarly stopping... no improvement after {} Epochs.'.format(
                        args.patience))
                break

        epoch_finetune_loss = None  # reset loss
        epoch_finetune_acc = None
        epoch_finetune_acc_top5 = None

    del state
    torch.cuda.empty_cache()
    gc.collect()  # release unreferenced memory
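# A minimal sketch of the frozen-encoder pattern used in finetune() above:
# the encoder runs under torch.no_grad(), so no graph is built through it
# and only the classification head receives gradients. The tiny modules
# and shapes here are hypothetical stand-ins.
import torch
import torch.nn as nn

encoder = nn.Linear(32, 16)    # stand-in for a pretrained encoder
head = nn.Linear(16, 10)       # supervised classification head
optimiser = torch.optim.SGD(head.parameters(), lr=0.1)  # head only
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(4, 32)
target = torch.randint(0, 10, (4,))

encoder.eval()                 # freeze (also fixes dropout / BN statistics)
head.train()

optimiser.zero_grad()
with torch.no_grad():          # encoder forward pass outside the autograd graph
    h = encoder(inputs)
loss = criterion(head(h), target)
loss.backward()                # gradients flow into the head only
optimiser.step()

assert encoder.weight.grad is None  # the encoder was never touched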
def pretrain(encoder, mlp, dataloaders, args):
    ''' Pretrain script - SimCLR

        Pretrain the encoder and projection head with a contrastive
        NT-Xent loss.
    '''
    mode = 'pretrain'

    ''' Optimisers '''
    optimiser = get_optimiser((encoder, mlp), mode, args)

    ''' Schedulers '''
    # Warmup scheduler
    if args.warmup_epochs > 0:
        for param_group in optimiser.param_groups:
            param_group['lr'] = (1e-12 / args.warmup_epochs) * args.learning_rate

        # Cosine LR decay after the warmup epochs
        lr_decay = lr_scheduler.CosineAnnealingLR(
            optimiser, (args.n_epochs - args.warmup_epochs),
            eta_min=0.0, last_epoch=-1)
    else:
        # Cosine LR decay
        lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.n_epochs,
                                                  eta_min=0.0, last_epoch=-1)

    ''' Loss / Criterion '''
    criterion = SimclrCriterion(batch_size=args.batch_size, normalize=True,
                                temperature=args.temperature).cuda()

    # Initialise variables
    args.writer = SummaryWriter(args.summaries_dir)

    best_valid_loss = np.inf
    patience_counter = 0

    ''' Pretrain loop '''
    for epoch in range(args.n_epochs):

        # Train models
        encoder.train()
        mlp.train()

        sample_count = 0
        run_loss = 0

        # Print setup for distributed training: only print on one node.
        if args.print_progress:
            logging.info('\nEpoch {}/{}:\n'.format(epoch + 1, args.n_epochs))
            # tqdm for process (rank) 0 only when using distributed training
            train_dataloader = tqdm(dataloaders['pretrain'])
        else:
            train_dataloader = dataloaders['pretrain']

        ''' epoch loop '''
        for i, (inputs, _) in enumerate(train_dataloader):
            inputs = inputs.cuda(non_blocking=True)

            # Forward pass
            optimiser.zero_grad()

            # Retrieve the two augmented views (stacked along the channel dim)
            x_i, x_j = torch.split(inputs, [3, 3], dim=1)

            # Get the encoder representations
            h_i = encoder(x_i)
            h_j = encoder(x_j)

            # Get the nonlinear transformation of the representations
            z_i = mlp(h_i)
            z_j = mlp(h_j)

            # Calculate the NT-Xent loss
            loss = criterion(z_i, z_j)

            loss.backward()
            optimiser.step()

            torch.cuda.synchronize()

            sample_count += inputs.size(0)
            run_loss += loss.item()

        epoch_pretrain_loss = run_loss / len(dataloaders['pretrain'])

        ''' Update Schedulers '''
        # TODO: Improve / add an lr_scheduler for the warmup
        if args.warmup_epochs > 0 and epoch + 1 <= args.warmup_epochs:
            wu_lr = (float(epoch + 1) / args.warmup_epochs) * args.learning_rate
            save_lr = optimiser.param_groups[0]['lr']
            optimiser.param_groups[0]['lr'] = wu_lr
        else:
            # After warmup, decay lr with CosineAnnealingLR
            lr_decay.step()

        ''' Printing '''
        if args.print_progress:  # only validate using process 0
            logging.info('\n[Train] loss: {:.4f}'.format(epoch_pretrain_loss))

            args.writer.add_scalars('epoch_loss',
                                    {'pretrain': epoch_pretrain_loss},
                                    epoch + 1)
            args.writer.add_scalars('lr',
                                    {'pretrain': optimiser.param_groups[0]['lr']},
                                    epoch + 1)

        state = {
            #'args': args,
            'encoder': encoder.state_dict(),
            'mlp': mlp.state_dict(),
            'optimiser': optimiser.state_dict(),
            'epoch': epoch,
        }
        torch.save(state, args.checkpoint_dir)

        # For the best performing epoch, reset patience and save the model;
        # otherwise update patience.
        if epoch_pretrain_loss <= best_valid_loss:
            patience_counter = 0
            best_epoch = epoch + 1
            best_valid_loss = epoch_pretrain_loss
        else:
            patience_counter += 1
            if patience_counter == (args.patience - 10):
                logging.info('\nPatience counter {}/{}.'.format(
                    patience_counter, args.patience))
            elif patience_counter == args.patience:
                logging.info(
                    '\nEarly stopping... no improvement after {} Epochs.'.format(
                        args.patience))
                break

        epoch_pretrain_loss = None  # reset loss

        del state
        torch.cuda.empty_cache()
        gc.collect()  # release unreferenced memory
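# A minimal sketch of the two-view pipeline in pretrain() above: the
# dataloader concatenates both augmented views along the channel dimension,
# and torch.split recovers them. The NT-Xent below is a simplified
# stand-in for SimclrCriterion, written from the SimCLR definition; the
# random tensors stand in for encoder + projection-head outputs.
import torch
import torch.nn.functional as F

batch, temperature = 4, 0.5
inputs = torch.randn(batch, 6, 32, 32)         # two 3-channel views stacked
x_i, x_j = torch.split(inputs, [3, 3], dim=1)  # recover the views

# Hypothetical projections z = mlp(encoder(x)), L2-normalised
z_i = F.normalize(torch.randn(batch, 128), dim=1)
z_j = F.normalize(torch.randn(batch, 128), dim=1)

z = torch.cat([z_i, z_j], dim=0)               # (2N, d)
sim = z @ z.t() / temperature                  # pairwise cosine similarities
sim.fill_diagonal_(float('-inf'))              # mask self-similarity

# The positive for sample k is its other view: index (k + N) mod 2N
pos = torch.arange(2 * batch).roll(batch)
loss = F.cross_entropy(sim, pos)
print(f'NT-Xent loss: {loss.item():.3f}')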
def supervised(encoder, mlp, dataloaders, args):
    ''' Supervised train script - SimCLR

        Train the encoder and the supervised classification head with a
        Cross Entropy Loss.
    '''
    mode = 'pretrain'

    ''' Optimisers '''
    # Optimise both the encoder and the supervised head
    optimiser = get_optimiser((encoder, mlp), mode, args)

    ''' Schedulers '''
    # Warmup scheduler
    if args.warmup_epochs > 0:
        for param_group in optimiser.param_groups:
            param_group['lr'] = (1e-12 / args.warmup_epochs) * args.learning_rate

        # Cosine LR decay after the warmup epochs
        lr_decay = lr_scheduler.CosineAnnealingLR(
            optimiser, (args.n_epochs - args.warmup_epochs),
            eta_min=0.0, last_epoch=-1)
    else:
        # Cosine LR decay
        lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.n_epochs,
                                                  eta_min=0.0, last_epoch=-1)

    ''' Loss / Criterion '''
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Initialise variables
    args.writer = SummaryWriter(args.summaries_dir)

    best_valid_loss = np.inf
    patience_counter = 0

    n_batches = len(dataloaders['train'])
    print(f"n_batches: {n_batches}")

    ''' Train loop '''
    for epoch in range(args.n_epochs):

        # Train models
        encoder.train()
        mlp.train()

        # Print setup for distributed training: only print on one node.
        if args.print_progress:
            logging.info('\nEpoch {}/{}:\n'.format(epoch + 1, args.n_epochs))
            # tqdm for process (rank) 0 only when using distributed training
            train_dataloader = tqdm(dataloaders['train'])
        else:
            train_dataloader = dataloaders['train']

        if epoch == args.n_epochs - 1:
            # Profile the final epoch only: wait through all but the last
            # three batches, warm up for one batch and record one batch
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA
                    ],
                    profile_memory=True,
                    with_stack=True,
                    with_flops=True,
                    schedule=torch.profiler.schedule(wait=n_batches - 3,
                                                     warmup=1,
                                                     active=1,
                                                     repeat=0)) as p:
                epoch_pretrain_loss = supervised_train_epoch(
                    encoder, mlp, args, train_dataloader, optimiser,
                    criterion, lr_decay, epoch, profiler=p)

            table = key_averages_with_stack(p.profiler.function_events).table(
                sort_by="self_cuda_time_total",
                row_limit=-1,
                top_level_events_only=False)
            print(table)
            p.export_stacks(
                os.path.join(args.model_dir, 'profiler_pretrain.stacks'),
                'self_cuda_time_total')

            # Write the table to a txt file
            with open(
                    os.path.join(args.model_dir,
                                 'profiler_pretrain_supervised.txt'),
                    'w') as profiler_log:
                profiler_log.write(table)

            # Write the profiler output to csv with a custom function
            save_events_table(
                key_averages_with_stack(p.profiler.function_events),
                os.path.join(args.model_dir,
                             'profiler_pretrain_supervised.csv'),
                times_path=os.path.join(
                    args.model_dir, 'final_times_pretrain_supervised.txt'),
                row_limit=-1,
                top_level_events_only=False)
        else:
            epoch_pretrain_loss = supervised_train_epoch(
                encoder, mlp, args, train_dataloader, optimiser,
                criterion, lr_decay, epoch, profiler=None)

        # For the best performing epoch, reset patience;
        # otherwise update patience.
        if epoch_pretrain_loss <= best_valid_loss:
            patience_counter = 0
            best_epoch = epoch + 1
            best_valid_loss = epoch_pretrain_loss
        else:
            patience_counter += 1
            if patience_counter == (args.patience - 10):
                logging.info('\nPatience counter {}/{}.'.format(
                    patience_counter, args.patience))
            elif patience_counter == args.patience:
                logging.info(
                    '\nEarly stopping... no improvement after {} Epochs.'.format(
                        args.patience))
                break

        epoch_pretrain_loss = None  # reset loss

        torch.cuda.empty_cache()
        gc.collect()  # release unreferenced memory
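# A minimal, self-contained sketch of the torch.profiler scheduling used
# above: with wait=n_batches - 3, warmup=1, active=1 the profiler idles
# through most of the epoch and records only the penultimate batch. The
# dummy model and loop below are hypothetical; the real code delegates the
# epoch to supervised_train_epoch, which must call profiler.step() once
# per batch for the schedule to advance.
import torch

model = torch.nn.Linear(8, 2)
batches = [torch.randn(4, 8) for _ in range(10)]
n_batches = len(batches)

with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        profile_memory=True,
        schedule=torch.profiler.schedule(wait=n_batches - 3, warmup=1,
                                         active=1, repeat=0)) as p:
    for x in batches:
        model(x).sum().backward()
        p.step()  # advance the profiler schedule once per batch

print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))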
def supervised(encoder, mlp, dataloaders, args):
    ''' Supervised train script - SimCLR

        Train the encoder and the supervised classification head with a
        Cross Entropy Loss.
    '''
    mode = 'pretrain'

    ''' Optimisers '''
    # Optimise both the encoder and the supervised head
    optimiser = get_optimiser((encoder, mlp), mode, args)

    ''' Schedulers '''
    # Warmup scheduler
    if args.warmup_epochs > 0:
        for param_group in optimiser.param_groups:
            param_group['lr'] = (1e-12 / args.warmup_epochs) * args.learning_rate

        # Cosine LR decay after the warmup epochs
        lr_decay = lr_scheduler.CosineAnnealingLR(
            optimiser, (args.n_epochs - args.warmup_epochs),
            eta_min=0.0, last_epoch=-1)
    else:
        # Cosine LR decay
        lr_decay = lr_scheduler.CosineAnnealingLR(optimiser, args.n_epochs,
                                                  eta_min=0.0, last_epoch=-1)

    ''' Loss / Criterion '''
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Initialise variables
    args.writer = SummaryWriter(args.summaries_dir)

    best_valid_loss = np.inf
    patience_counter = 0

    ''' Train loop '''
    for epoch in range(args.n_epochs):

        # Train models
        encoder.train()
        mlp.train()

        sample_count = 0
        run_loss = 0
        run_top1 = 0.0
        run_top5 = 0.0

        # Print setup for distributed training: only print on one node.
        if args.print_progress:
            logging.info('\nEpoch {}/{}:\n'.format(epoch + 1, args.n_epochs))
            # tqdm for process (rank) 0 only when using distributed training
            train_dataloader = tqdm(dataloaders['train'])
        else:
            train_dataloader = dataloaders['train']

        ''' epoch loop '''
        for i, (inputs, target) in enumerate(train_dataloader):
            inputs = inputs.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # Forward pass
            optimiser.zero_grad()

            h = encoder(inputs)  # take encoder representations
            output = mlp(h)

            loss = criterion(output, target)

            loss.backward()
            optimiser.step()

            torch.cuda.synchronize()

            sample_count += inputs.size(0)
            run_loss += loss.item()

            predicted = output.argmax(1)
            acc = (predicted == target).sum().item() / target.size(0)
            run_top1 += acc

            _, output_topk = output.topk(5, 1, True, True)
            acc_top5 = (output_topk == target.view(-1, 1).expand_as(
                output_topk)).sum().item() / target.size(0)  # num correct
            run_top5 += acc_top5

        epoch_pretrain_loss = run_loss / len(dataloaders['train'])  # sample_count
        epoch_pretrain_acc = run_top1 / len(dataloaders['train'])
        epoch_pretrain_acc_top5 = run_top5 / len(dataloaders['train'])

        ''' Update Schedulers '''
        # TODO: Improve / add an lr_scheduler for the warmup
        if args.warmup_epochs > 0 and epoch + 1 <= args.warmup_epochs:
            wu_lr = (float(epoch + 1) / args.warmup_epochs) * args.learning_rate
            save_lr = optimiser.param_groups[0]['lr']
            optimiser.param_groups[0]['lr'] = wu_lr
        else:
            # After warmup, decay lr with CosineAnnealingLR
            lr_decay.step()

        ''' Printing '''
        if args.print_progress:  # only validate using process 0
            logging.info('\n[Train] loss: {:.4f}'.format(epoch_pretrain_loss))

            args.writer.add_scalars('epoch_loss',
                                    {'pretrain': epoch_pretrain_loss},
                                    epoch + 1)
            args.writer.add_scalars('supervised_epoch_acc',
                                    {'pretrain': epoch_pretrain_acc},
                                    epoch + 1)
            args.writer.add_scalars('supervised_epoch_acc_top5',
                                    {'pretrain': epoch_pretrain_acc_top5},
                                    epoch + 1)
            args.writer.add_scalars('lr',
                                    {'pretrain': optimiser.param_groups[0]['lr']},
                                    epoch + 1)

        state = {
            #'args': args,
            'encoder': encoder.state_dict(),
            'mlp': mlp.state_dict(),
            'optimiser': optimiser.state_dict(),
            'epoch': epoch,
        }
        torch.save(state, args.checkpoint_dir)

        # For the best performing epoch, reset patience and save the model;
        # otherwise update patience.
        if epoch_pretrain_loss <= best_valid_loss:
            patience_counter = 0
            best_epoch = epoch + 1
            best_valid_loss = epoch_pretrain_loss
        else:
            patience_counter += 1
            if patience_counter == (args.patience - 10):
                logging.info('\nPatience counter {}/{}.'.format(
                    patience_counter, args.patience))
            elif patience_counter == args.patience:
                logging.info(
                    '\nEarly stopping... no improvement after {} Epochs.'.format(
                        args.patience))
                break

        epoch_pretrain_loss = None  # reset loss

        del state
        torch.cuda.empty_cache()
        gc.collect()  # release unreferenced memory
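# A minimal sketch of the top-1 / top-5 accuracy bookkeeping used in the
# epoch loops above, on hypothetical logits. topk returns the indices of
# the 5 largest logits per row; a prediction counts as top-5 correct if
# any of them matches the target.
import torch

output = torch.randn(4, 10)             # logits for 4 samples, 10 classes
target = torch.randint(0, 10, (4,))

acc_top1 = (output.argmax(1) == target).sum().item() / target.size(0)

_, output_topk = output.topk(5, 1, True, True)   # (4, 5) class indices
acc_top5 = (output_topk == target.view(-1, 1).expand_as(output_topk)
            ).sum().item() / target.size(0)

print(f'top-1: {acc_top1:.2f}, top-5: {acc_top5:.2f}')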