def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch,
                    fp16_scaler, args):
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
    for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        # update weight decay and learning rate according to their schedule
        it = len(data_loader) * epoch + it  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        # move images to gpu
        images = [im.cuda(non_blocking=True) for im in images]
        # teacher and student forward passes + compute dino loss
        with torch.cuda.amp.autocast(fp16_scaler is not None):
            teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
            student_output = student(images)
            loss = dino_loss(student_output, teacher_output, epoch)

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()), force=True)
            sys.exit(1)

        # student update
        optimizer.zero_grad()
        param_norms = None
        if fp16_scaler is None:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                fp16_scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # logging
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
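# The loop above leans on two helpers from the repo's utils module. A minimal sketch
# of both, following the reference DINO implementation (per-parameter L2-norm clipping,
# and zeroing the last layer's gradients for the first `freeze_last_layer` epochs);
# treat it as illustrative rather than the exact project code:
def clip_gradients(model, clip):
    norms = []
    for name, p in model.named_parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            norms.append(param_norm.item())
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                p.grad.data.mul_(clip_coef)
    return norms


def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None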
def construct_graph(self, model_cls):
    # model
    self.model = model_cls()
    # loss and accuracy
    train_ops = _construct_loss_and_eval_ops_for_regression(self.model, self.train_data, is_training=True)
    self.train_loss, self.train_tr_metric, self.train_val_metric = train_ops[0], train_ops[1], train_ops[2]
    eval_ops = _construct_loss_and_eval_ops_for_regression(self.model, self.eval_data, is_training=False)
    self.eval_loss, self.eval_tr_metric, self.eval_val_metric, self.eval_tr_input, self.eval_tr_output, \
        self.eval_tr_func, self.eval_val_input, self.eval_val_output, self.eval_val_func, self.eval_val_preds, \
        self.eval_val_sigma = eval_ops
    # optimisation
    training_variables = tf.trainable_variables()
    training_gradients = tf.gradients(self.train_loss, training_variables)
    training_gradients = utils.clip_gradients(training_gradients, 0, 0)
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.outer_lr)
    self.global_step = tf.compat.v1.train.get_or_create_global_step()
    self.train_op = optimizer.apply_gradients(
        list(zip(training_gradients, training_variables)), self.global_step)
def construct_graph(self, model_cls):
    # construct model
    self.model = model_cls()
    # construct loss and accuracy ops
    self.train_loss, self.train_tr_metric, self.train_val_metric = \
        _construct_loss_and_eval_ops_for_classification(self.model, self.train_data, is_training=True)
    self.train_eval_loss, self.train_eval_tr_metric, self.train_eval_val_metric = \
        _construct_loss_and_eval_ops_for_classification(self.model, self.train_data, is_training=False)
    self.eval_loss, self.eval_tr_metric, self.eval_val_metric = \
        _construct_loss_and_eval_ops_for_classification(self.model, self.eval_data, is_training=False)
    self.test_loss, self.test_tr_metric, self.test_val_metric = \
        _construct_loss_and_eval_ops_for_classification(self.model, self.test_data, is_training=False)
    # construct optimisation ops
    training_variables = tf.compat.v1.trainable_variables()
    training_gradients = tf.gradients(self.train_loss, training_variables)
    training_gradients = utils.clip_gradients(training_gradients, 0, 0)  # gradient clipping is not used
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.outer_lr)
    self.global_step = tf.compat.v1.train.get_or_create_global_step()
    self.train_op = optimizer.apply_gradients(
        list(zip(training_gradients, training_variables)), self.global_step)
def train_joint( data_loader, encoder, decoder, encoder_optimizer, decoder_optimizer, epoch, grad_clip, print_freq, gradnorm_optimizer, loss_weight_generation, loss_weight_ranking, gradnorm_loss, gradnorm_alpha, initial_generation_loss, initial_ranking_loss, ): """ Perform one training epoch. """ loss_weights = [loss_weight_generation, loss_weight_ranking] decoder.train() if encoder: encoder.train() losses = AverageMeter() # Loop over training batches for i, (images, target_captions, caption_lengths) in enumerate(data_loader): target_captions = target_captions.to(device) caption_lengths = caption_lengths.to(device) images = images.to(device) # Forward propagation if encoder: images = encoder(images) decode_lengths = caption_lengths.squeeze(1) - 1 scores, decode_lengths, images_embedded, captions_embedded, alphas = decoder.forward_joint( images, target_captions, decode_lengths) loss_generation = decoder.loss(scores, target_captions, decode_lengths, alphas) loss_ranking = decoder.loss_ranking(images_embedded, captions_embedded) loss = loss_weights[0] * loss_generation + loss_weights[ 1] * loss_ranking decoder_optimizer.zero_grad() if encoder_optimizer: encoder_optimizer.zero_grad() loss.backward(retain_graph=True) # Get the gradients of the shared layers and calculate their l2-norm named_params = dict(decoder.named_parameters()) shared_params = [ param for param_name, param in named_params.items() if param_name in decoder.SHARED_PARAMS and param.requires_grad ] G1R = torch.autograd.grad(loss_generation, shared_params, retain_graph=True, create_graph=True) G1R_flattened = torch.cat([g.view(-1) for g in G1R]) G1 = torch.norm(loss_weights[0] * G1R_flattened.data, 2).unsqueeze(0) G2R = torch.autograd.grad(loss_ranking, shared_params) G2R_flattened = torch.cat([g.view(-1) for g in G2R]) G2 = torch.norm(loss_weights[1] * G2R_flattened.data, 2).unsqueeze(0) # Calculate the average gradient norm across all tasks G_avg = torch.div(torch.add(G1, G2), 2) # Calculate relative losses lhat1 = torch.div(loss_generation, initial_generation_loss) lhat2 = torch.div(loss_ranking, initial_ranking_loss) lhat_avg = torch.div(torch.add(lhat1, lhat2), 2) # Calculate relative inverse training rates inv_rate1 = torch.div(lhat1, lhat_avg) inv_rate2 = torch.div(lhat2, lhat_avg) # Calculate the gradient norm target for this batch C1 = G_avg * (inv_rate1**gradnorm_alpha) C2 = G_avg * (inv_rate2**gradnorm_alpha) # Calculate the gradnorm loss Lgrad = torch.add(gradnorm_loss(G1, C1.data), gradnorm_loss(G2, C2.data)) # Backprop and perform an optimization step gradnorm_optimizer.zero_grad() Lgrad.backward() gradnorm_optimizer.step() # Clip gradients if grad_clip: clip_gradients(decoder_optimizer, grad_clip) if encoder_optimizer: clip_gradients(encoder_optimizer, grad_clip) # Update weights decoder_optimizer.step() if encoder_optimizer: encoder_optimizer.step() # Keep track of metrics losses.update(loss.item(), sum(decode_lengths).item()) # Log status if i % print_freq == 0: logging.info( "Epoch: {0}[Batch {1}/{2}]\t" "Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t Loss weights: Generation: {3:.4f} Ranking: {4:.4f}" .format( epoch, i, len(data_loader), loss_weights[0].item(), loss_weights[1].item(), loss=losses, )) # Renormalize the gradnorm weights coef = 2 / torch.add(loss_weight_generation, loss_weight_ranking) loss_weights = [ coef * loss_weight_generation, coef * loss_weight_ranking ] logging.info("\n * LOSS - {loss.avg:.3f}\n".format(loss=losses))
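# How the GradNorm inputs to train_joint() are typically wired up -- a sketch only,
# with names mirroring the arguments above; the learning rate, alpha value, and the
# use of L1Loss are assumptions, not taken from this repo:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The loss weights are 1-element trainable tensors updated by their own optimizer
loss_weight_generation = torch.ones(1, requires_grad=True, device=device)
loss_weight_ranking = torch.ones(1, requires_grad=True, device=device)
gradnorm_optimizer = torch.optim.Adam([loss_weight_generation, loss_weight_ranking], lr=0.025)

# GradNorm penalizes the gap between each task's gradient norm G_i and its target C_i
gradnorm_loss = torch.nn.L1Loss(reduction="sum")
gradnorm_alpha = 1.5  # asymmetry hyperparameter from the GradNorm paper

# initial_generation_loss / initial_ranking_loss are the task losses recorded on the
# first batch, so the relative losses lhat_i = L_i / L_i(0) can be formed above.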
def train(model, train_loader, val_loader, logger, optim, output, gpu=True, **train_config):
    # optimizer and learning rate setup
    # torch.autograd.set_detect_anomaly(True)
    accuracy = VQAAccuracy()  # accuracy metric
    # class_frequence = torch.zeros_like(train_loader.dataset[0]['answer'][:-50])
    # for sample in train_loader.dataset:
    #     class_frequence += torch.ceil(sample['answer'][:-50])
    # class_weight = torch.exp(-class_frequence/class_frequence.sum())
    # print(class_weight)
    lbce = LogitBinaryCrossEntropy()
    # set up the learning-rate schedule
    scheduler_func = lambda x: lr_lambda_update(x, **train_config)
    lr_scheduler = LambdaLR(optim, lr_lambda=scheduler_func)
    # lr_scheduler = ExponentialLR(optim, 0.5**(1/50000))
    iteration = 0
    best_val_accuracy = 0
    best_epoch = 0
    patient = 0
    saving_epoch = 4
    log_train_accuracy = {}
    log_train_loss = {}
    if val_loader is not None:
        log_val_loss = {}
        log_val_accuracy = {}
    for epoch in range(1, train_config["epoch"] + 1):
        # while iteration < train_config["max_iterations"]:
        model.train()
        log_msg = "Epoch %d of Train:" % (epoch)
        print(log_msg)
        total_accuracy = 0
        total_loss = 0
        for i, sample in enumerate(train_loader):
            iteration += 1
            input_ids = Variable(sample["input_ids"])
            token_type_ids = Variable(sample["token_type_ids"])
            attention_mask = Variable(sample["attention_mask"])
            img = Variable(sample["img_feature"])
            # context = Variable(torch.zeros_like(sample["context_feature"]))
            context = Variable(sample["context_feature"])
            labels = Variable(sample['answer'])
            bbox = Variable(sample['bbox'])
            ocrbbox = Variable(sample['ocrbbox'])
            if gpu:
                input_ids = input_ids.cuda()
                token_type_ids = token_type_ids.cuda()
                attention_mask = attention_mask.cuda()
                img = img.cuda()
                context = context.cuda()
                labels = labels.cuda()
                bbox = bbox.cuda()
                ocrbbox = ocrbbox.cuda()
            batch_size = img.size(0)
            prediction = model(img, bbox, input_ids, token_type_ids, attention_mask, context, ocrbbox)
            loss, lc, lo = lbce(labels, prediction)
            if epoch <= 13:
                loss.backward()
            elif epoch % 2 == 0:
                lc.backward()
            else:
                lo.backward()
            # if (i+1)%2==0:
            lr_scheduler.step(epoch)
            clip_gradients(model, train_config["max_grad_l2_norm"], train_config["clip_norm_mode"])
            optim.step()
            optim.zero_grad()
            # track accuracy
            batch_accuracy = accuracy(labels.data, prediction.data)
            total_accuracy += batch_accuracy * batch_size
            # track loss
            total_loss += loss.data
            if (i + 1) % 10 == 0:
                log_msg = "[%d/%d/%d] iter:%d accuracy:%2.2f loss:%.5f lr: %f" % (
                    epoch, len(train_loader), i, iteration,
                    batch_accuracy * 100, loss.data, optim.param_groups[0]['lr'])
                print(log_msg)
        if val_loader is not None:
            log_msg = "Epoch %d of Validation:" % (epoch)
            print(log_msg)
            val_accuracy, val_loss = evaluate(model, val_loader, logger, gpu, **train_config)
        print("Result")
        log_msg = "Train accuracy:%2.2f, Train loss: %.5f" % (
            total_accuracy / len(train_loader.dataset) * 100,
            total_loss / len(train_loader.dataset))
        if val_loader is not None:
            log_msg += ", Val accuracy:%2.2f, Val loss: %.5f" % (val_accuracy * 100, val_loss)
        print(log_msg)
        log_train_accuracy["epoch " + str(epoch)] = total_accuracy.cpu().numpy().tolist() / len(train_loader.dataset) * 100
        log_train_loss["epoch " + str(epoch)] = total_loss.cpu().numpy().tolist() / len(train_loader.dataset)
        if val_loader is not None:
            log_val_accuracy["epoch " + str(epoch)] = val_accuracy.cpu().numpy().tolist() * 100
            log_val_loss["epoch " + str(epoch)] = val_loss.cpu().numpy().tolist()
        if (val_loader is not None and val_accuracy > best_val_accuracy) or (val_loader is None and epoch >= saving_epoch):
            model_name = 'model_%s.pth' % ('epoch' + str(epoch) if val_loader is None else 'best')
            model_path = os.path.join(output, model_name)
            utils.save_model(model_path, model, epoch, optim)
            if val_loader is not None:
                best_val_accuracy = val_accuracy
                best_epoch = epoch
                patient = 0
        elif val_loader is not None:
            patient += 1
            if patient >= 15:
                print("Patient %d early stop!!" % patient)
                break
            print("Patient %d" % patient)
    log_msg = "best val accuracy : %2.2f at %d epoch. " % (best_val_accuracy * 100, best_epoch)
    print(log_msg)
    logger.add("train accuracy", log_train_accuracy)
    logger.add("train loss", log_train_loss)
    if val_loader is not None:
        logger.add("best val accuracy", best_val_accuracy.cpu().numpy().tolist())
        logger.add("best_epoch", best_epoch)
        logger.add("val loss", log_val_loss)
        logger.add("val accuracy", log_val_accuracy)
    logger.save_log()
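# The clip_gradients(model, max_grad_l2_norm, clip_norm_mode) call in the VQA loop
# above matches the Pythia-style helper signature; a sketch under that assumption:
import torch

def clip_gradients(model, max_grad_l2_norm, clip_norm_mode):
    if max_grad_l2_norm is None:
        return
    if clip_norm_mode == "all":
        # clip the global L2 norm of all parameter gradients to max_grad_l2_norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_l2_norm)
    else:
        raise NotImplementedError("clip_norm_mode %s is not supported" % clip_norm_mode)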
def train(
    model_name,
    data_loader,
    encoder,
    decoder,
    encoder_optimizer,
    decoder_optimizer,
    epoch,
    grad_clip,
    print_freq,
):
    """Perform one training epoch."""
    decoder.train()
    if encoder:
        encoder.train()

    losses = AverageMeter()

    # Loop over training batches
    for i, (images, target_captions, caption_lengths) in enumerate(data_loader):
        target_captions = target_captions.to(device)
        caption_lengths = caption_lengths.to(device)
        images = images.to(device)

        # Forward propagation
        if encoder:
            images = encoder(images)
        decode_lengths = caption_lengths.squeeze(1) - 1

        if model_name == MODEL_BOTTOM_UP_TOP_DOWN_RANKING:
            scores, decode_lengths, images_embedded, captions_embedded, alphas = decoder.forward_joint(
                images, target_captions, decode_lengths)
            loss = decoder.loss(scores, target_captions, decode_lengths, alphas)
        else:
            scores, decode_lengths, alphas = decoder(images, target_captions, decode_lengths)
            loss = decoder.loss(scores, target_captions, decode_lengths, alphas)

        decoder_optimizer.zero_grad()
        if encoder_optimizer:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip:
            clip_gradients(decoder_optimizer, grad_clip)
            if encoder_optimizer:
                clip_gradients(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer:
            encoder_optimizer.step()

        # Keep track of metrics
        losses.update(loss.item(), sum(decode_lengths).item())

        # Log status
        if i % print_freq == 0:
            logging.info(
                "Epoch: {0}[Batch {1}/{2}]\t"
                "Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t".format(
                    epoch, i, len(data_loader), loss=losses))

    logging.info("\n * LOSS - {loss.avg:.3f}\n".format(loss=losses))
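# Here clip_gradients takes an optimizer rather than a model, which matches the common
# captioning-tutorial idiom of clamping gradient values per parameter group; a sketch
# under that assumption (not necessarily this repo's exact helper):
def clip_gradients(optimizer, grad_clip):
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)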
def main(argv): del argv # unused np.random.seed(FLAGS.seed) #tf.compat.v1.set_random_seed(FLAGS.seed) print("Testing: ", FLAGS.testing, f"\t Seed: {FLAGS.seed}") FLAGS.encoder_sizes = [int(size) for size in FLAGS.encoder_sizes] FLAGS.decoder_sizes = [int(size) for size in FLAGS.decoder_sizes] if 0 in FLAGS.encoder_sizes: FLAGS.encoder_sizes.remove(0) if 0 in FLAGS.decoder_sizes: FLAGS.decoder_sizes.remove(0) # Make up full exp name timestamp = datetime.now().strftime("%y%m%d") full_exp_name = "{}_{}".format(timestamp, FLAGS.exp_name) outdir = os.path.join(FLAGS.basedir, full_exp_name) if not os.path.exists(outdir): os.makedirs(outdir) checkpoint_prefix = os.path.join(outdir, "ckpt") print("Full exp name: ", full_exp_name) ################################### # Define data specific parameters # ################################### if FLAGS.data_type == "hmnist": FLAGS.data_dir = "data/hmnist/hmnist_mnar.npz" data_dim = 784 time_length = 10 num_classes = 10 decoder = BernoulliDecoder img_shape = (28, 28, 1) val_split = 50000 elif FLAGS.data_type == "physionet": if FLAGS.data_dir == "": FLAGS.data_dir = "data/physionet/physionet.npz" data_dim = 35 time_length = 48 num_classes = 2 decoder = GaussianDecoder elif FLAGS.data_type == "sprites": if FLAGS.data_dir == "": FLAGS.data_dir = "data/sprites/sprites.npz" data_dim = 12288 time_length = 8 decoder = GaussianDecoder img_shape = (64, 64, 3) val_split = 8000 else: raise ValueError( "Data type must be one of ['hmnist', 'physionet', 'sprites']") ############# # Load data # ############# data = np.load(FLAGS.data_dir) x_train_full = data['x_train_full'] x_train_miss = data['x_train_miss'] m_train_miss = data['m_train_miss'] if FLAGS.data_type in ['hmnist', 'physionet']: y_train = data['y_train'] if FLAGS.testing: if FLAGS.data_type in ['hmnist', 'sprites']: x_val_full = data['x_test_full'] x_val_miss = data['x_test_miss'] m_val_miss = data['m_test_miss'] if FLAGS.data_type == 'hmnist': y_val = data['y_test'] elif FLAGS.data_type == 'physionet': x_val_full = data['x_train_full'] x_val_miss = data['x_train_miss'] m_val_miss = data['m_train_miss'] y_val = data['y_train'] m_val_artificial = data["m_train_artificial"] elif FLAGS.data_type in ['hmnist', 'sprites']: x_val_full = x_train_full[val_split:] x_val_miss = x_train_miss[val_split:] m_val_miss = m_train_miss[val_split:] if FLAGS.data_type == 'hmnist': y_val = y_train[val_split:] x_train_full = x_train_full[:val_split] x_train_miss = x_train_miss[:val_split] m_train_miss = m_train_miss[:val_split] y_train = y_train[:val_split] elif FLAGS.data_type == 'physionet': x_val_full = data["x_val_full"] # full for artificial missings x_val_miss = data["x_val_miss"] m_val_miss = data["m_val_miss"] m_val_artificial = data["m_val_artificial"] y_val = data["y_val"] else: raise ValueError( "Data type must be one of ['hmnist', 'physionet', 'sprites']") tf_x_train_miss = DataLoader(MyDataset(x_train_miss, m_train_miss), shuffle=True, batch_size=FLAGS.batch_size) tf_x_val_miss = iter( DataLoader(MyDataset(x_val_miss, m_val_miss), batch_size=FLAGS.batch_size, shuffle=False)) tf_x_test_miss = DataLoader(MyDataset(x_val_miss, m_val_miss), batch_size=len(x_val_miss), shuffle=False) # Build Conv2D preprocessor for image data if FLAGS.data_type in ['hmnist', 'sprites']: print("Using CNN preprocessor") image_preprocessor = ImagePreprocessor(img_shape, FLAGS.cnn_sizes, FLAGS.cnn_kernel_size) elif FLAGS.data_type == 'physionet': image_preprocessor = None else: raise ValueError( "Data type must be one of ['hmnist', 
'physionet', 'sprites']") ############### # Build model # ############### if FLAGS.model_type == "vae": model = VAE(latent_dim=FLAGS.latent_dim, data_dim=data_dim, time_length=time_length, encoder_sizes=FLAGS.encoder_sizes, encoder=DiagonalEncoder, decoder_sizes=FLAGS.decoder_sizes, decoder=decoder, image_preprocessor=image_preprocessor, window_size=FLAGS.window_size, beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K) elif FLAGS.model_type == "hi-vae": model = HI_VAE(latent_dim=FLAGS.latent_dim, data_dim=data_dim, time_length=time_length, encoder_sizes=FLAGS.encoder_sizes, encoder=DiagonalEncoder, decoder_sizes=FLAGS.decoder_sizes, decoder=decoder, image_preprocessor=image_preprocessor, window_size=FLAGS.window_size, beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K) elif FLAGS.model_type == "gp-vae": encoder = BandedJointEncoder if FLAGS.banded_covar else JointEncoder model = GP_VAE(latent_dim=FLAGS.latent_dim, data_dim=data_dim, time_length=time_length, encoder_sizes=FLAGS.encoder_sizes, encoder=encoder, decoder_sizes=FLAGS.decoder_sizes, decoder=decoder, kernel=FLAGS.kernel, sigma=FLAGS.sigma, length_scale=FLAGS.length_scale, kernel_scales=FLAGS.kernel_scales, image_preprocessor=image_preprocessor, window_size=FLAGS.window_size, beta=FLAGS.beta, M=FLAGS.M, K=FLAGS.K) else: raise ValueError( "Model type must be one of ['vae', 'hi-vae', 'gp-vae']") clip_gradients(model, FLAGS.gradient_clip) ######################## # Training preparation # ######################## print("GPU support: ", tf.test.is_gpu_available()) print("Training...") #_ = tf.compat.v1.train.get_or_create_global_step() #trainable_vars = model.get_trainable_vars() #optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) optimizer = optim.Adam(model.parameters(), lr=1e-3) print("Encoder: ", model.encoder) print("Decoder: ", model.decoder) if model.preprocessor is not None: print("Preprocessor: ", model.preprocessor.net.summary()) #saver = tf.compat.v1.train.Checkpoint(optimizer=optimizer, encoder=model.encoder.net, # decoder=model.decoder.net, preprocessor=model.preprocessor.net, # optimizer_step=tf.compat.v1.train.get_or_create_global_step()) else: #saver = tf.compat.v1.train.Checkpoint(optimizer=optimizer, encoder=model.encoder.net, decoder=model.decoder.net, # optimizer_step=tf.compat.v1.train.get_or_create_global_step()) pass #summary_writer = tf.contrib.summary.create_file_writer(outdir, flush_millis=10000) if FLAGS.num_steps == 0: num_steps = FLAGS.num_epochs * len(x_train_miss) // FLAGS.batch_size else: num_steps = FLAGS.num_steps if FLAGS.print_interval == 0: FLAGS.print_interval = num_steps // FLAGS.num_epochs ############ # Training # ############ losses_train = [] losses_val = [] t0 = time.time() #with summary_writer.as_default(), tf.contrib.summary.always_record_summaries(): for i, (x_seq, m_seq) in enumerate(tf_x_train_miss): if i >= num_steps: break try: #print(x_seq.shape) optimizer.zero_grad() #with tf.GradientTape() as tape: # tape.watch(trainable_vars) loss = model.compute_loss(x_seq, m_mask=m_seq) losses_train.append(loss.detach().numpy()) #grads = loss.grad #print(grads) #grads = tape.gradient(loss, trainable_vars) #grads = [np.nan_to_num(grad) for grad in grads] #t#orch.nn.utils.clip_grad_norm(model.parameters(),clip) #grads, global_norm = tf.clip_by_global_norm(grads, FLAGS.gradient_clip) optimizer.step() #optimizer.apply_gradients(zip(grads, trainable_vars), # global_step=tf.compat.v1.train.get_or_create_global_step()) # Print intermediate results if i % FLAGS.print_interval == 0: 
print("================================================") print("Learning rate: {} | Global gradient norm: ".format( optimizer.param_groups[0]['lr'])) print("Step {}) Time = {:2f}".format(i, time.time() - t0)) loss, nll, kl = model.compute_loss(x_seq, m_mask=m_seq, return_parts=True) print( "Train loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}".format( loss, nll, kl)) torch.save(model, checkpoint_prefix) #saver.save(checkpoint_prefix) print("loss_train", loss) print("kl_train", kl) print("nll_train", nll) #tf.contrib.summary.scalar("nll_train", nll) # Validation loss x_val_batch, m_val_batch = next(tf_x_val_miss) val_loss, val_nll, val_kl = model.compute_loss( x_val_batch, m_mask=m_val_batch, return_parts=True) losses_val.append(val_loss.detach().numpy()) print("Validation loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}". format(val_loss, val_nll, val_kl)) print("loss_val", val_loss) print("kl_val", val_kl) print("nll_val", val_nll) if FLAGS.data_type in ["hmnist", "sprites"]: # Draw reconstructed images x_hat = model.decode(model.encode(x_seq).sample()).mean tf.contrib.summary.image( "input_train", x_seq.reshape([-1] + list(img_shape))) tf.contrib.summary.image( "reconstruction_train", x_hat.reshape([-1] + list(img_shape))) elif FLAGS.data_type == 'physionet': # Eval MSE and AUROC on entire val set x_val_miss_batches = np.array_split(x_val_miss, FLAGS.batch_size, axis=0) x_val_full_batches = np.array_split(x_val_full, FLAGS.batch_size, axis=0) m_val_artificial_batches = np.array_split(m_val_artificial, FLAGS.batch_size, axis=0) get_val_batches = lambda: zip(x_val_miss_batches, x_val_full_batches, m_val_artificial_batches) n_missings = m_val_artificial.sum() mse_miss = np.sum([ model.compute_mse(x, y=y, m_mask=m).item() for x, y, m in get_val_batches() ]) / n_missings x_val_imputed = np.vstack([ model.decode( model.encode(x_batch).mean).mean.detach().numpy() for x_batch in x_val_miss_batches ]) x_val_imputed[m_val_miss == 0] = x_val_miss[ m_val_miss == 0] # impute gt observed values x_val_imputed = x_val_imputed.reshape( [-1, time_length * data_dim]) val_split = len(x_val_imputed) // 2 cls_model = LogisticRegression(solver='liblinear', tol=1e-10, max_iter=10000) cls_model.fit(x_val_imputed[:val_split], y_val[:val_split]) probs = cls_model.predict_proba( x_val_imputed[val_split:])[:, 1] auroc = roc_auc_score(y_val[val_split:], probs) print("MSE miss: {:.4f} | AUROC: {:.4f}".format( mse_miss, auroc)) # Update learning rate (used only for physionet with decay=0.5) if i > 0 and i % (10 * FLAGS.print_interval) == 0: optimizer._lr = max(0.5 * optimizer._lr, 0.1 * FLAGS.learning_rate) t0 = time.time() except KeyboardInterrupt: saver.save(checkpoint_prefix) if FLAGS.debug: import ipdb ipdb.set_trace() break ############## # Evaluation # ############## print("Evaluation...") # Split data on batches x_val_miss_batches = np.array_split(x_val_miss, FLAGS.batch_size, axis=0) x_val_full_batches = np.array_split(x_val_full, FLAGS.batch_size, axis=0) if FLAGS.data_type == 'physionet': m_val_batches = np.array_split(m_val_artificial, FLAGS.batch_size, axis=0) else: m_val_batches = np.array_split(m_val_miss, FLAGS.batch_size, axis=0) get_val_batches = lambda: zip(x_val_miss_batches, x_val_full_batches, m_val_batches) # Compute NLL and MSE on missing values n_missings = m_val_artificial.sum( ) if FLAGS.data_type == 'physionet' else m_val_miss.sum() nll_miss = np.sum([ model.compute_nll(x, y=y, m_mask=m).item() #.detach().numpy() for x, y, m in get_val_batches() ]) / n_missings mse_miss = np.sum([ model.compute_mse(x, 
y=y, m_mask=m, binary=FLAGS.data_type == "hmnist").item() #.detach().numpy() for x, y, m in get_val_batches() ]) / n_missings print("NLL miss: {:.4f}".format(nll_miss)) print("MSE miss: {:.4f}".format(mse_miss)) # Save imputed values z_mean = [ model.encode(x_batch).mean.detach().numpy() for x_batch in x_val_miss_batches ] np.save(os.path.join(outdir, "z_mean"), np.vstack(z_mean)) x_val_imputed = np.vstack( [model.decode(z_batch).mean.detach().numpy() for z_batch in z_mean]) np.save(os.path.join(outdir, "imputed_no_gt"), x_val_imputed) # impute gt observed values x_val_imputed[m_val_miss == 0] = x_val_miss[m_val_miss == 0] np.save(os.path.join(outdir, "imputed"), x_val_imputed) if FLAGS.data_type == "hmnist": # AUROC evaluation using Logistic Regression x_val_imputed = np.round(x_val_imputed) x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim]) cls_model = LogisticRegression(solver='lbfgs', multi_class='multinomial', tol=1e-10, max_iter=10000) val_split = len(x_val_imputed) // 2 cls_model.fit(x_val_imputed[:val_split], y_val[:val_split]) probs = cls_model.predict_proba(x_val_imputed[val_split:]) auprc = average_precision_score( np.eye(num_classes)[y_val[val_split:]], probs) auroc = roc_auc_score(np.eye(num_classes)[y_val[val_split:]], probs) print("AUROC: {:.4f}".format(auroc)) print("AUPRC: {:.4f}".format(auprc)) elif FLAGS.data_type == "sprites": auroc, auprc = 0, 0 elif FLAGS.data_type == "physionet": # Uncomment to preserve some z_samples and their reconstructions # for i in range(5): # z_sample = [model.encode(x_batch).sample().numpy() for x_batch in x_val_miss_batches] # np.save(os.path.join(outdir, "z_sample_{}".format(i)), np.vstack(z_sample)) # x_val_imputed_sample = np.vstack([model.decode(z_batch).mean().numpy() for z_batch in z_sample]) # np.save(os.path.join(outdir, "imputed_sample_{}_no_gt".format(i)), x_val_imputed_sample) # x_val_imputed_sample[m_val_miss == 0] = x_val_miss[m_val_miss == 0] # np.save(os.path.join(outdir, "imputed_sample_{}".format(i)), x_val_imputed_sample) # AUROC evaluation using Logistic Regression x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim]) val_split = len(x_val_imputed) // 2 cls_model = LogisticRegression(solver='liblinear', tol=1e-10, max_iter=10000) cls_model.fit(x_val_imputed[:val_split], y_val[:val_split]) probs = cls_model.predict_proba(x_val_imputed[val_split:])[:, 1] auprc = average_precision_score(y_val[val_split:], probs) auroc = roc_auc_score(y_val[val_split:], probs) print("AUROC: {:.4f}".format(auroc)) print("AUPRC: {:.4f}".format(auprc)) # Visualize reconstructions if FLAGS.data_type in ["hmnist", "sprites"]: img_index = 0 if FLAGS.data_type == "hmnist": img_shape = (28, 28) cmap = "gray" elif FLAGS.data_type == "sprites": img_shape = (64, 64, 3) cmap = None fig, axes = plt.subplots(nrows=3, ncols=x_val_miss.shape[1], figsize=(2 * x_val_miss.shape[1], 6)) x_hat = model.decode( model.encode(x_val_miss[img_index:img_index + 1]).mean()).mean().numpy() seqs = [ x_val_miss[img_index:img_index + 1], x_hat, x_val_full[img_index:img_index + 1] ] for axs, seq in zip(axes, seqs): for ax, img in zip(axs, seq[0]): ax.imshow(img.reshape(img_shape), cmap=cmap) ax.axis('off') suptitle = FLAGS.model_type + f" reconstruction, NLL missing = {mse_miss}" fig.suptitle(suptitle, size=18) fig.savefig( os.path.join(outdir, FLAGS.data_type + "_reconstruction.pdf")) results_all = [ FLAGS.seed, FLAGS.model_type, FLAGS.data_type, FLAGS.kernel, FLAGS.beta, FLAGS.latent_dim, FLAGS.num_epochs, FLAGS.batch_size, FLAGS.learning_rate, 
FLAGS.window_size, FLAGS.kernel_scales, FLAGS.sigma, FLAGS.length_scale, len(FLAGS.encoder_sizes), FLAGS.encoder_sizes[0] if len(FLAGS.encoder_sizes) > 0 else 0, len(FLAGS.decoder_sizes), FLAGS.decoder_sizes[0] if len(FLAGS.decoder_sizes) > 0 else 0, FLAGS.cnn_kernel_size, FLAGS.cnn_sizes, nll_miss, mse_miss, losses_train[-1], losses_val[-1], auprc, auroc, FLAGS.testing, FLAGS.data_dir ] with open(os.path.join(outdir, "results.tsv"), "w") as outfile: outfile.write( "seed\tmodel\tdata\tkernel\tbeta\tz_size\tnum_epochs" "\tbatch_size\tlearning_rate\twindow_size\tkernel_scales\t" "sigma\tlength_scale\tencoder_depth\tencoder_width\t" "decoder_depth\tdecoder_width\tcnn_kernel_size\t" "cnn_sizes\tNLL\tMSE\tlast_train_loss\tlast_val_loss\tAUPRC\tAUROC\ttesting\tdata_dir\n" ) outfile.write("\t".join(map(str, results_all))) with open(os.path.join(outdir, "training_curve.tsv"), "w") as outfile: outfile.write("\t".join(map(str, losses_train))) outfile.write("\n") outfile.write("\t".join(map(str, losses_val))) print("Training finished.")
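# In the GP-VAE script above, clip_gradients(model, FLAGS.gradient_clip) is called once
# before the training loop, which suggests hook-based clipping (mirroring the
# commented-out tf.clip_by_global_norm path). A speculative sketch of such a helper:
import torch

def clip_gradients(model, clip_value):
    # clamp every gradient to [-clip_value, clip_value] on each backward pass
    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(lambda grad, c=clip_value: torch.clamp(grad, -c, c))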
def main(): is_training = tf.placeholder(tf.bool, name='is_training') num_class = args.num_way num_shot = args.num_shot num_query = args.num_query keep_prob = tf.placeholder(tf.float32, name='keep_prob') support_label = tf.placeholder(tf.int32, (None, ), 'support_label') query_label = tf.placeholder(tf.int32, (None, ), 'query_label') support_x = tf.placeholder(tf.float32, (None, 640), 'support_x') query_x = tf.placeholder(tf.float32, (None, 640), 'query_x') support_feature = support_x query_feature = query_x support_feature = tf.reshape(support_feature, (batch_size, num_class, num_shot, 640)) query_feature = tf.reshape(query_feature, (batch_size, num_class, num_query, 640)) support_label_reshape = tf.reshape(support_label, (batch_size, num_class, num_shot)) query_label_reshape = tf.reshape(query_label, (batch_size, num_class, num_query)) awgim = model.AWGIM(args, keep_prob, is_training) loss_cls, accuracy, tr_loss, tr_accuracy, support_reconstruction, query_reconstruction = \ awgim.forward(support_feature, support_label_reshape, query_feature, query_label_reshape) reg_term = tf.add_n([ tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'kernel' in v.name ]) loss_meta = loss_cls + args.alpha_1 * tr_loss + args.alpha_2 * support_reconstruction + args.alpha_3 * query_reconstruction Batch = tf.Variable(0, trainable=False, dtype=tf.float32, name='global_step') learning_rate = tf.train.exponential_decay( learning_rate=args.learning_rate, global_step=Batch, decay_steps=args.step_size, decay_rate=0.2, staircase=True) optim = tf.contrib.opt.AdamWOptimizer(learning_rate=learning_rate, weight_decay=args.weight_decay) meta_weights = [v for v in tf.trainable_variables()] print(meta_weights) if args.stage == 'train': meta_gradients = utils.grads_and_vars(loss_meta, meta_weights, reg_term) meta_gradients = utils.clip_gradients(meta_gradients, args.gradient_threshold, args.gradient_norm_threshold) train_op = optim.apply_gradients(zip(meta_gradients, meta_weights), global_step=Batch) config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) save_path = utils.save(args) print(save_path) os.makedirs(save_path, exist_ok=True) if args.stage == 'test': print(tf.train.latest_checkpoint(save_path)) saver.restore(sess, tf.train.latest_checkpoint(save_path)) print('load model') if args.data_set == 'mini': loader_train = dataset_mini.dataset_mini('train', args) loader_val = dataset_mini.dataset_mini('val', args) loader_test = dataset_mini.dataset_mini('test', args) else: loader_train = dataset_tiered.dataset_tiered('train', args) loader_val = dataset_tiered.dataset_tiered('val', args) loader_test = dataset_tiered.dataset_tiered('test', args) if args.stage == 'train': print('Load PKL data') loader_train.load_data_pkl() loader_val.load_data_pkl() else: loader_test.load_data_pkl() val_best_accuracy = 0. n_iter = 0 record_val_acc = [] if args.stage == 'train': for epoch in range(args.epoch): training_accuracy, training_loss, acc_cp, acc_real, c_loss, d_loss, g_loss = [], [], [], [], [], [], [] # training_loss_cls = [] for epi in range(100): support_input, s_labels, query_input, q_labels = utils.load_batch( args, loader_train, args.batch_size, True, loader_val) feed_dict = { support_x: support_input, support_label: s_labels, query_x: query_input, query_label: q_labels, is_training: True, keep_prob: 1. 
- args.dropout } outs = sess.run([train_op, loss_meta, accuracy, Batch], feed_dict=feed_dict) training_accuracy.append(outs[2]) training_loss.append(outs[1]) n_iter += 1 if (epoch + 1) % 3 == 0: log = 'epoch: ', epoch + 1, 'accuracy: ', np.mean( training_accuracy), 'loss: ', np.mean(training_loss) print(log) if (epoch + 1) % 3 == 0: accuracy_val = [] loss_val = [] for epi in range(100): support_input, s_labels, query_input, q_labels = utils.load_batch( args, loader_val, args.batch_size, training=False) outs = sess.run( [loss_meta, accuracy, Batch], feed_dict={ support_x: support_input, support_label: s_labels, query_x: query_input, query_label: q_labels, is_training: False, keep_prob: 1. }) accuracy_val.append(outs[1]) loss_val.append(outs[0]) mean_acc = np.mean(accuracy_val) std_acc = np.std(accuracy_val) ci95 = 1.96 * std_acc / np.sqrt(100) print( ' Val Acc:{:.4f},std:{:.4f},ci95:{:.4f}'.format( mean_acc, std_acc, ci95), 'at epoch: ', epoch + 1) record_val_acc.append(mean_acc) if mean_acc > val_best_accuracy: val_best_accuracy = mean_acc saver.save(sess, save_path=save_path + 'model.ckpt', global_step=Batch) if (epoch + 1) % 100 == 0: saver.save(sess, save_path=save_path + 'model.ckpt', global_step=Batch) elif args.stage == 'test': accuracy_test = [] loss_test = [] num = 600 for epi in range(num): support_input, s_labels, query_input, q_labels = utils.load_batch( args, loader_test, args.batch_size, False) outs = sess.run( [loss_meta, accuracy], feed_dict={ support_x: support_input, support_label: s_labels, query_x: query_input, query_label: q_labels, is_training: False, keep_prob: 1. }) accuracy_test.append(outs[1]) loss_test.append(outs[0]) mean_acc = np.mean(accuracy_test) std_acc = np.std(accuracy_test) ci95 = 1.96 * std_acc / np.sqrt(num) print('Acc:{:.4f},std:{:.4f},ci95:{:.4f}'.format( mean_acc, std_acc, ci95)) sess.close()
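# A sketch of what utils.clip_gradients(gradients, gradient_threshold,
# gradient_norm_threshold) plausibly does here; the same three-argument signature
# appears in the construct_graph() snippets above, where (0, 0) disables clipping.
# The exact implementation is an assumption:
import tensorflow as tf

def clip_gradients(gradients, gradient_threshold, gradient_norm_threshold):
    if gradient_threshold > 0:
        gradients = [tf.clip_by_value(g, -gradient_threshold, gradient_threshold)
                     for g in gradients]
    if gradient_norm_threshold > 0:
        gradients, _ = tf.clip_by_global_norm(gradients, gradient_norm_threshold)
    return gradients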
def main():
    parser = argparse.ArgumentParser(
        "DINO training CLI",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-b", "--batch-size", type=int, default=4)
    parser.add_argument("-d", "--device", type=str, choices=("cpu", "cuda"), default="cuda")
    parser.add_argument("-l", "--logging-freq", type=int, default=200)
    parser.add_argument("--momentum-teacher", type=float, default=0.9995)
    parser.add_argument("-c", "--n-crops", type=int, default=4)
    parser.add_argument("-e", "--n-epochs", type=int, default=100)
    parser.add_argument("-o", "--out-dim", type=int, default=1024)
    parser.add_argument("-t", "--tensorboard-dir", type=str, default="logs")
    parser.add_argument("--clip-grad", type=float, default=2.0)
    parser.add_argument("--norm-last-layer", action="store_true")
    parser.add_argument("--batch-size-eval", type=int, default=8)
    parser.add_argument("--teacher-temp", type=float, default=0.04)
    parser.add_argument("--student-temp", type=float, default=0.1)
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("-w", "--weight-decay", type=float, default=0.4)
    args = parser.parse_args()
    print(vars(args))

    # Parameters
    vit_name, dim = "deit_small_patch16_224", 384
    path_dataset_train = pathlib.Path("data/imagenette2-320/train")
    path_dataset_val = pathlib.Path("data/imagenette2-320/val")
    path_labels = pathlib.Path("data/imagenette_labels.json")
    logging_path = pathlib.Path(args.tensorboard_dir)
    device = torch.device(args.device)
    n_workers = 1  # at most 2 on my little machine

    # Data related
    with path_labels.open("r") as f:
        label_mapping = json.load(f)

    transform_aug = DataAugmentation(size=224, n_local_crops=args.n_crops - 2)
    transform_plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224)),
    ])

    dataset_train_aug = ImageFolder(path_dataset_train, transform=transform_aug)
    dataset_train_plain = ImageFolder(path_dataset_train, transform=transform_plain)
    dataset_val_plain = ImageFolder(path_dataset_val, transform=transform_plain)

    if dataset_train_plain.classes != dataset_val_plain.classes:
        raise ValueError("Inconsistent classes")

    data_loader_train_aug = DataLoader(
        dataset_train_aug, batch_size=args.batch_size, shuffle=True,
        drop_last=True, num_workers=n_workers, pin_memory=True,
    )
    data_loader_train_plain = DataLoader(
        dataset_train_plain, batch_size=args.batch_size_eval,
        drop_last=False, num_workers=n_workers,
    )
    data_loader_val_plain = DataLoader(
        dataset_val_plain, batch_size=args.batch_size_eval,
        drop_last=False, num_workers=n_workers,
    )
    data_loader_val_plain_subset = DataLoader(
        dataset_val_plain, batch_size=args.batch_size_eval, drop_last=False,
        sampler=SubsetRandomSampler(list(range(0, len(dataset_val_plain), 50))),
        num_workers=n_workers,
    )

    # Logging
    writer = SummaryWriter(logging_path)
    writer.add_text("arguments", json.dumps(vars(args)))

    # Neural network related
    student_vit = timm.create_model(vit_name, pretrained=args.pretrained)
    teacher_vit = timm.create_model(vit_name, pretrained=args.pretrained)

    student = MultiCropWrapper(
        student_vit,
        Head(
            dim,
            args.out_dim,
            norm_last_layer=args.norm_last_layer,
        ),
    )
    teacher = MultiCropWrapper(teacher_vit, Head(dim, args.out_dim))
    student, teacher = student.to(device), teacher.to(device)

    teacher.load_state_dict(student.state_dict())

    for p in teacher.parameters():
        p.requires_grad = False

    # Loss related
    loss_inst = Loss(
        args.out_dim,
        teacher_temp=args.teacher_temp,
        student_temp=args.student_temp,
    ).to(device)
    lr = 0.0005 * args.batch_size / 256
    optimizer = torch.optim.AdamW(
        student.parameters(),
        lr=lr,
        weight_decay=args.weight_decay,
    )

    # Training loop
    n_batches = len(dataset_train_aug) // args.batch_size
    best_acc = 0
    n_steps = 0

    for e in range(args.n_epochs):
        for i, (images, _) in tqdm.tqdm(enumerate(data_loader_train_aug), total=n_batches):
            if n_steps % args.logging_freq == 0:
                student.eval()

                # Embedding
                embs, imgs, labels_ = compute_embedding(
                    student.backbone,
                    data_loader_val_plain_subset,
                )
                writer.add_embedding(
                    embs,
                    metadata=[label_mapping[l] for l in labels_],
                    label_img=imgs,
                    global_step=n_steps,
                    tag="embeddings",
                )

                # KNN
                current_acc = compute_knn(
                    student.backbone,
                    data_loader_train_plain,
                    data_loader_val_plain,
                )
                writer.add_scalar("knn-accuracy", current_acc, n_steps)
                if current_acc > best_acc:
                    torch.save(student, logging_path / "best_model.pth")
                    best_acc = current_acc

                student.train()

            images = [img.to(device) for img in images]

            teacher_output = teacher(images[:2])
            student_output = student(images)

            loss = loss_inst(student_output, teacher_output)

            optimizer.zero_grad()
            loss.backward()
            clip_gradients(student, args.clip_grad)
            optimizer.step()

            with torch.no_grad():
                for student_ps, teacher_ps in zip(student.parameters(), teacher.parameters()):
                    teacher_ps.data.mul_(args.momentum_teacher)
                    teacher_ps.data.add_(
                        (1 - args.momentum_teacher) * student_ps.detach().data)

            writer.add_scalar("train_loss", loss, n_steps)

            n_steps += 1
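# compute_knn() above scores the frozen student backbone with a k-nearest-neighbour
# classifier; a minimal sketch using scikit-learn (k=20 and the feature-extraction
# details are assumptions, not this repo's exact code):
import numpy as np
import torch
from sklearn.neighbors import KNeighborsClassifier

def compute_knn(backbone, data_loader_train, data_loader_val, k=20):
    device = next(backbone.parameters()).device
    feats, labels = {}, {}
    with torch.no_grad():
        for name, loader in (("train", data_loader_train), ("val", data_loader_val)):
            xs, ys = [], []
            for imgs, y in loader:
                xs.append(backbone(imgs.to(device)).cpu().numpy())
                ys.append(y.numpy())
            feats[name], labels[name] = np.concatenate(xs), np.concatenate(ys)
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(feats["train"], labels["train"])
    return knn.score(feats["val"], labels["val"])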
def main(args): if args.model_name is not None: print('Preparing to train model: {}'.format(args.model_name)) global device device = torch.device( 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu') sc_will_happen = args.self_critical_from_epoch != -1 if args.validate is None and args.lr_scheduler == 'ReduceLROnPlateau': print( 'ERROR: you need to enable validation in order to use default lr_scheduler (ReduceLROnPlateau)' ) print('Hint: use something like --validate=coco:val2017') sys.exit(1) # Create model directory if not os.path.exists(args.model_path): os.makedirs(args.model_path) # Image preprocessing, normalization for the pretrained resnet transform = transforms.Compose([ # transforms.Resize((256, 256)), transforms.RandomCrop(args.crop_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) scorers = {} if args.validation_scoring is not None or sc_will_happen: assert not ( args.validation_scoring is None and sc_will_happen ), "Please provide a metric when using self-critical training" for s in args.validation_scoring.split(','): s = s.lower().strip() if s == 'cider': from eval.cider import Cider scorers['CIDEr'] = Cider() if s == 'ciderd': from eval.ciderD.ciderD import CiderD scorers['CIDEr-D'] = CiderD(df=args.cached_words) ######################## # Set Model parameters # ######################## # Store parameters gotten from arguments separately: arg_params = ModelParams.fromargs(args) print("Model parameters inferred from command arguments: ") print(arg_params) start_epoch = 0 ############################### # Load existing model state # # and update Model parameters # ############################### state = None if args.load_model: try: state = torch.load(args.load_model, map_location=device) except AttributeError: print( 'WARNING: Old model found. Please use model_update.py in the model before executing this script.' ) exit(1) new_external_features = arg_params.features.external params = ModelParams(state, arg_params=arg_params) if len(new_external_features ) and params.features.external != new_external_features: print('WARNING: external features changed: ', params.features.external, new_external_features) print('Updating feature paths...') params.update_ext_features(new_external_features) start_epoch = state['epoch'] print('Loaded model {} at epoch {}'.format(args.load_model, start_epoch)) else: params = arg_params params.command_history = [] if params.rnn_hidden_init == 'from_features' and params.skip_start_token: print( "ERROR: Please remove --skip_start_token if you want to use image features " " to initialize hidden and cell states. 
<start> token is needed to trigger " " the process of sequence generation, since we don't have image features " " embedding as the first input token.") sys.exit(1) # Force set the following hierarchical model parameters every time: if arg_params.hierarchical_model: params.hierarchical_model = True params.max_sentences = arg_params.max_sentences params.weight_sentence_loss = arg_params.weight_sentence_loss params.weight_word_loss = arg_params.weight_word_loss params.dropout_stopping = arg_params.dropout_stopping params.dropout_fc = arg_params.dropout_fc params.coherent_sentences = arg_params.coherent_sentences params.coupling_alpha = arg_params.coupling_alpha params.coupling_beta = arg_params.coupling_beta assert args.replace or \ not os.path.isdir(os.path.join(args.output_root, args.model_path, get_model_name(args, params))) or \ not (args.load_model and not args.validate_only), \ '{} already exists. If you want to replace it or resume training please use --replace flag. ' \ 'If you want to validate a loaded model without training it, use --validate_only flag.' \ 'Otherwise specify a different model name using --model_name flag.'\ .format(os.path.join(args.output_root, args.model_path, get_model_name(args, params))) if args.load_model: print("Final model parameters (loaded model + command arguments): ") print(params) ############################## # Load dataset configuration # ############################## dataset_configs = DatasetParams(args.dataset_config_file) if args.dataset is None and not args.validate_only: print('ERROR: No dataset selected!') print( 'Please supply a training dataset with the argument --dataset DATASET' ) print('The following datasets are configured in {}:'.format( args.dataset_config_file)) for ds, _ in dataset_configs.config.items(): if ds not in ('DEFAULT', 'generic'): print(' ', ds) sys.exit(1) if args.validate_only: if args.load_model is None: print( 'ERROR: for --validate_only you need to specify a model to evaluate using --load_model MODEL' ) sys.exit(1) else: dataset_params = dataset_configs.get_params(args.dataset) for i in dataset_params: i.config_dict['no_tokenize'] = args.no_tokenize i.config_dict['show_tokens'] = args.show_tokens i.config_dict['skip_start_token'] = params.skip_start_token if params.hierarchical_model: i.config_dict['hierarchical_model'] = True i.config_dict['max_sentences'] = params.max_sentences i.config_dict['crop_regions'] = False if args.validate is not None: validation_dataset_params = dataset_configs.get_params(args.validate) for i in validation_dataset_params: i.config_dict['no_tokenize'] = args.no_tokenize i.config_dict['show_tokens'] = args.show_tokens i.config_dict['skip_start_token'] = params.skip_start_token if params.hierarchical_model: i.config_dict['hierarchical_model'] = True i.config_dict['max_sentences'] = params.max_sentences i.config_dict['crop_regions'] = False ####################### # Load the vocabulary # ####################### # For pre-trained models attempt to obtain # saved vocabulary from the model itself: if args.load_model and params.vocab is not None: print("Loading vocabulary from the model file:") vocab = params.vocab else: if args.vocab is None: print( "ERROR: You must specify the vocabulary to be used for training using " "--vocab flag.\nTry --vocab AUTO if you want the vocabulary to be " "either generated from the training dataset or loaded from cache." 
) sys.exit(1) print("Loading / generating vocabulary:") vocab = get_vocab(args, dataset_params) print('Size of the vocabulary is {}'.format(len(vocab))) ########################## # Initialize data loader # ########################## ext_feature_sets = [ params.features.external, params.persist_features.external ] if not args.validate_only: print('Loading dataset: {} with {} workers'.format( args.dataset, args.num_workers)) if params.skip_start_token: print("Skipping the use of <start> token...") data_loader, ef_dims = get_loader( dataset_params, vocab, transform, args.batch_size, shuffle=True, num_workers=args.num_workers, ext_feature_sets=ext_feature_sets, skip_images=not params.has_internal_features(), verbose=args.verbose, unique_ids=sc_will_happen) if sc_will_happen: gts_sc = get_ground_truth_captions(data_loader.dataset) gts_sc_valid = None if args.validate is not None: valid_loader, ef_dims = get_loader( validation_dataset_params, vocab, transform, args.batch_size, shuffle=True, num_workers=args.num_workers, ext_feature_sets=ext_feature_sets, skip_images=not params.has_internal_features(), verbose=args.verbose) gts_sc_valid = get_ground_truth_captions( valid_loader.dataset) if sc_will_happen else None ######################################### # Setup (optional) TensorBoardX logging # ######################################### writer = None if args.tensorboard: if SummaryWriter is not None: model_name = get_model_name(args, params) timestamp = datetime.now().strftime('%Y%m%d%H%M%S') log_dir = os.path.join( args.output_root, 'log_tb/{}_{}'.format(model_name, timestamp)) writer = SummaryWriter(log_dir=log_dir) print("INFO: Logging TensorBoardX events to {}".format(log_dir)) else: print( "WARNING: SummaryWriter object not available. " "Hint: Please install TensorBoardX using pip install tensorboardx" ) ###################### # Build the model(s) # ###################### # Set per parameter learning rate here, if supplied by the user: if args.lr_word_decoder is not None: if not params.hierarchical_model: print( "ERROR: Setting word decoder learning rate currently supported in Hierarchical Model only." 
) sys.exit(1) lr_dict = {'word_decoder': args.lr_word_decoder} else: lr_dict = {} model = EncoderDecoder(params, device, len(vocab), state, ef_dims, lr_dict=lr_dict) ###################### # Optimizer and loss # ###################### sc_activated = False opt_params = model.get_opt_params() # Loss and optimizer if params.hierarchical_model: criterion = HierarchicalXEntropyLoss( weight_sentence_loss=params.weight_sentence_loss, weight_word_loss=params.weight_word_loss) elif args.share_embedding_weights: criterion = SharedEmbeddingXentropyLoss(param_lambda=0.15) else: criterion = nn.CrossEntropyLoss() if sc_will_happen: # save it for later if args.self_critical_loss == 'sc': from model.loss import SelfCriticalLoss rl_criterion = SelfCriticalLoss() elif args.self_critical_loss == 'sc_with_diversity': from model.loss import SelfCriticalWithDiversityLoss rl_criterion = SelfCriticalWithDiversityLoss() elif args.self_critical_loss == 'sc_with_relative_diversity': from model.loss import SelfCriticalWithRelativeDiversityLoss rl_criterion = SelfCriticalWithRelativeDiversityLoss() elif args.self_critical_loss == 'sc_with_bleu_diversity': from model.loss import SelfCriticalWithBLEUDiversityLoss rl_criterion = SelfCriticalWithBLEUDiversityLoss() elif args.self_critical_loss == 'sc_with_repetition': from model.loss import SelfCriticalWithRepetitionLoss rl_criterion = SelfCriticalWithRepetitionLoss() elif args.self_critical_loss == 'mixed': from model.loss import MixedLoss rl_criterion = MixedLoss() elif args.self_critical_loss == 'mixed_with_face': from model.loss import MixedWithFACELoss rl_criterion = MixedWithFACELoss(vocab_size=len(vocab)) elif args.self_critical_loss in [ 'sc_with_penalty', 'sc_with_penalty_throughout', 'sc_masked_tokens' ]: raise ValueError('Deprecated loss, use \'sc\' loss') else: raise ValueError('Invalid self-critical loss') print('Selected self-critical loss is', rl_criterion) if start_epoch >= args.self_critical_from_epoch: criterion = rl_criterion sc_activated = True print('Self-critical loss training begins') # When using CyclicalLR, default learning rate should be always 1.0 if args.lr_scheduler == 'CyclicalLR': default_lr = 1. 
else: default_lr = 0.001 if sc_activated: optimizer = torch.optim.Adam( opt_params, lr=args.learning_rate if args.learning_rate else 5e-5, weight_decay=args.weight_decay) elif args.optimizer == 'adam': optimizer = torch.optim.Adam(opt_params, lr=default_lr, weight_decay=args.weight_decay) elif args.optimizer == 'rmsprop': optimizer = torch.optim.RMSprop(opt_params, lr=default_lr, weight_decay=args.weight_decay) elif args.optimizer == 'sgd': optimizer = torch.optim.SGD(opt_params, lr=default_lr, weight_decay=args.weight_decay) else: print('ERROR: unknown optimizer:', args.optimizer) sys.exit(1) # We don't want to initialize the optimizer if we are transfering # the language model from the regular model to hierarchical model transfer_language_model = False if arg_params.hierarchical_model and state and not state.get( 'hierarchical_model'): transfer_language_model = True # Set optimizer state to the one found in a loaded model, unless # we are doing a transfer learning step from flat to hierarchical model, # or we are using self-critical loss, # or the number of unique parameter groups has changed, or the user # has explicitly told us *not to* reuse optimizer parameters from before if state and not transfer_language_model and not sc_activated and not args.optimizer_reset: # Check that number of parameter groups is the same if len(optimizer.param_groups) == len( state['optimizer']['param_groups']): optimizer.load_state_dict(state['optimizer']) # override lr if set explicitly in arguments - # 1) Global learning rate: if args.learning_rate: for param_group in optimizer.param_groups: param_group['lr'] = args.learning_rate params.learning_rate = args.learning_rate else: params.learning_rate = default_lr # 2) Parameter-group specific learning rate: if args.lr_word_decoder is not None: # We want to give user an option to set learning rate for word_decoder # separately. 
Other exceptions can be added as needed: for param_group in optimizer.param_groups: if param_group.get('name') == 'word_decoder': param_group['lr'] = args.lr_word_decoder break if args.validate is not None and args.lr_scheduler == 'ReduceLROnPlateau': print('Using ReduceLROnPlateau learning rate scheduler') scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=2) elif args.lr_scheduler == 'StepLR': print('Using StepLR learning rate scheduler with step_size {}'.format( args.lr_step_size)) # Decrease the learning rate by the factor of gamma at every # step_size epochs (for example every 5 or 10 epochs): step_size = args.lr_step_size scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.5, last_epoch=-1) elif args.lr_scheduler == 'CyclicalLR': print( "Using Cyclical learning rate scheduler, lr range: [{},{}]".format( args.lr_cyclical_min, args.lr_cyclical_max)) step_size = len(data_loader) clr = cyclical_lr(step_size, min_lr=args.lr_cyclical_min, max_lr=args.lr_cyclical_max) n_groups = len(optimizer.param_groups) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, [clr] * n_groups) elif args.lr_scheduler is not None: print('ERROR: Invalid learing rate scheduler specified: {}'.format( args.lr_scheduler)) sys.exit(1) ################### # Train the model # ################### stats_postfix = None if args.validate_only: stats_postfix = args.validate if args.load_model: all_stats = init_stats(args, params, postfix=stats_postfix) else: all_stats = {} if args.force_epoch: start_epoch = args.force_epoch - 1 if not args.validate_only: total_step = len(data_loader) print( 'Start training with start_epoch={:d} num_epochs={:d} num_batches={:d} ...' .format(start_epoch, args.num_epochs, args.num_batches)) if args.teacher_forcing != 'always': print('\t k: {}'.format(args.teacher_forcing_k)) print('\t beta: {}'.format(args.teacher_forcing_beta)) print('Optimizer:', optimizer) if args.validate_only: stats = {} teacher_p = 1.0 if args.teacher_forcing != 'always': print( 'WARNING: teacher_forcing!=always, not yet implemented for --validate_only mode' ) epoch = start_epoch - 1 if str(epoch + 1) in all_stats.keys() and args.skip_existing_validations: print('WARNING: epoch {} already validated, skipping...'.format( epoch + 1)) return val_loss = do_validate(model, valid_loader, criterion, scorers, vocab, teacher_p, args, params, stats, epoch, sc_activated, gts_sc_valid) all_stats[str(epoch + 1)] = stats save_stats(args, params, all_stats, postfix=stats_postfix) else: for epoch in range(start_epoch, args.num_epochs): stats = {} begin = datetime.now() total_loss = 0 if params.hierarchical_model: total_loss_sent = 0 total_loss_word = 0 num_batches = 0 vocab_counts = { 'cnt': 0, 'max': 0, 'min': 9999, 'sum': 0, 'unk_cnt': 0, 'unk_sum': 0 } # If start self critical training if not sc_activated and sc_will_happen and epoch >= args.self_critical_from_epoch: if all_stats: best_ep, best_cider = max( [(ep, all_stats[ep]['validation_cider']) for ep in all_stats], key=lambda x: x[1]) print('Loading model from epoch', best_ep, 'which has the better score with', best_cider) state = torch.load( get_model_path(args, params, int(best_ep))) model = EncoderDecoder(params, device, len(vocab), state, ef_dims, lr_dict=lr_dict) opt_params = model.get_opt_params() optimizer = torch.optim.Adam(opt_params, lr=5e-5, weight_decay=args.weight_decay) criterion = rl_criterion print('Self-critical loss training begins') sc_activated = True for i, data in enumerate(data_loader): 
if params.hierarchical_model: (images, captions, lengths, image_ids, features, sorting_order, last_sentence_indicator) = data sorting_order = sorting_order.to(device) else: (images, captions, lengths, image_ids, features) = data if epoch == 0: unk = vocab('<unk>') for j in range(captions.shape[0]): # Flatten the caption in case it's a paragraph # this is harmless for regular captions too: xl = captions[j, :].view(-1) xw = xl > unk xu = xl == unk xwi = sum(xw).item() xui = sum(xu).item() vocab_counts['cnt'] += 1 vocab_counts['sum'] += xwi vocab_counts['max'] = max(vocab_counts['max'], xwi) vocab_counts['min'] = min(vocab_counts['min'], xwi) vocab_counts['unk_cnt'] += xui > 0 vocab_counts['unk_sum'] += xui # Set mini-batch dataset images = images.to(device) captions = captions.to(device) # Remove <start> token from targets if we are initializing the RNN # hidden state from image features: if params.rnn_hidden_init == 'from_features' and not params.hierarchical_model: # Subtract one from all lengths to match new target lengths: lengths = [x - 1 if x > 0 else x for x in lengths] targets = pack_padded_sequence(captions[:, 1:], lengths, batch_first=True)[0] else: if params.hierarchical_model: targets = prepare_hierarchical_targets( last_sentence_indicator, args.max_sentences, lengths, captions, device) else: targets = pack_padded_sequence(captions, lengths, batch_first=True)[0] sorting_order = None init_features = features[0].to(device) if len( features) > 0 and features[0] is not None else None persist_features = features[1].to(device) if len( features) > 1 and features[1] is not None else None # Forward, backward and optimize # Calculate the probability whether to use teacher forcing or not: # Iterate over batches: iteration = (epoch - start_epoch) * len(data_loader) + i teacher_p = get_teacher_prob(args.teacher_forcing_k, iteration, args.teacher_forcing_beta) # Allow model to log values at the last batch of the epoch writer_data = None if writer and (i == len(data_loader) - 1 or i == args.num_batches - 1): writer_data = {'writer': writer, 'epoch': epoch + 1} sample_len = captions.size(1) if args.self_critical_loss in [ 'mixed', 'mixed_with_face' ] else 20 if sc_activated: sampled_seq, sampled_log_probs, outputs = model.sample( images, init_features, persist_features, max_seq_length=sample_len, start_token_id=vocab('<start>'), trigram_penalty_alpha=args.trigram_penalty_alpha, stochastic_sampling=True, output_logprobs=True, output_outputs=True) sampled_seq = model.decoder.alt_prob_to_tensor( sampled_seq, device=device) else: outputs = model(images, init_features, captions, lengths, persist_features, teacher_p, args.teacher_forcing, sorting_order, writer_data=writer_data) if args.share_embedding_weights: # Weights of (HxH) projection matrix used for regularizing # models that share embedding weights projection = model.decoder.projection.weight loss = criterion(projection, outputs, targets) elif sc_activated: # get greedy decoding baseline model.eval() with torch.no_grad(): greedy_sampled_seq = model.sample( images, init_features, persist_features, max_seq_length=sample_len, start_token_id=vocab('<start>'), trigram_penalty_alpha=args.trigram_penalty_alpha, stochastic_sampling=False) greedy_sampled_seq = model.decoder.alt_prob_to_tensor( greedy_sampled_seq, device=device) model.train() if args.self_critical_loss in [ 'sc', 'sc_with_diversity', 'sc_with_relative_diversity', 'sc_with_bleu_diversity', 'sc_with_repetition' ]: loss, advantage = criterion( sampled_seq, sampled_log_probs, 
greedy_sampled_seq, [gts_sc[i] for i in image_ids], scorers, vocab, return_advantage=True) elif args.self_critical_loss in ['mixed']: loss, advantage = criterion( sampled_seq, sampled_log_probs, outputs, greedy_sampled_seq, [gts_sc[i] for i in image_ids], scorers, vocab, targets, lengths, gamma_ml_rl=args.gamma_ml_rl, return_advantage=True) elif args.self_critical_loss in ['mixed_with_face']: loss, advantage = criterion( sampled_seq, sampled_log_probs, outputs, greedy_sampled_seq, [gts_sc[i] for i in image_ids], scorers, vocab, captions, targets, lengths, gamma_ml_rl=args.gamma_ml_rl, return_advantage=True) else: raise ValueError('Invalid self-critical loss') if writer is not None and i % 100 == 0: writer.add_scalar('training_loss', loss.item(), epoch * len(data_loader) + i) writer.add_scalar('advantage', advantage, epoch * len(data_loader) + i) writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch * len(data_loader) + i) else: loss = criterion(outputs, targets) model.zero_grad() loss.backward() # Clip gradients if desired: if args.grad_clip is not None: # grad_norms = [x.grad.data.norm(2) for x in opt_params] # batch_max_grad = np.max(grad_norms) # if batch_max_grad > 10.0: # print('WARNING: gradient norms larger than 10.0') # torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.1) # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 0.1) clip_gradients(optimizer, args.grad_clip) # Update weights: optimizer.step() # CyclicalLR requires us to update LR at every minibatch: if args.lr_scheduler == 'CyclicalLR': scheduler.step() total_loss += loss.item() num_batches += 1 if params.hierarchical_model: _, loss_sent, _, loss_word = criterion.item_terms() total_loss_sent += float(loss_sent) total_loss_word += float(loss_word) # Print log info if (i + 1) % args.log_step == 0: print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, ' 'Perplexity: {:5.4f}'.format(epoch + 1, args.num_epochs, i + 1, total_step, loss.item(), np.exp(loss.item()))) sys.stdout.flush() if params.hierarchical_model: weight_sent, loss_sent, weight_word, loss_word = criterion.item_terms( ) print('Sentence Loss: {:.4f}, ' 'Word Loss: {:.4f}'.format( float(loss_sent), float(loss_word))) sys.stdout.flush() if i + 1 == args.num_batches: break end = datetime.now() stats['training_loss'] = total_loss / num_batches if params.hierarchical_model: stats['loss_sentence'] = total_loss_sent / num_batches stats['loss_word'] = total_loss_word / num_batches print('Epoch {} duration: {}, average loss: {:.4f}'.format( epoch + 1, end - begin, stats['training_loss'])) save_model(args, params, model.encoder, model.decoder, optimizer, epoch, vocab) if epoch == 0: vocab_counts['avg'] = vocab_counts['sum'] / vocab_counts['cnt'] vocab_counts['unk_cnt_per'] = 100 * vocab_counts[ 'unk_cnt'] / vocab_counts['cnt'] vocab_counts['unk_sum_per'] = 100 * vocab_counts[ 'unk_sum'] / vocab_counts['sum'] # print(vocab_counts) print(( 'Training data contains {sum} words in {cnt} captions (avg. 
{avg:.1f} w/c)' + ' with {unk_sum} <unk>s ({unk_sum_per:.1f}%)' + ' in {unk_cnt} ({unk_cnt_per:.1f}%) captions').format( **vocab_counts)) ############################################ # Validation loss and learning rate update # ############################################ if args.validate is not None and (epoch + 1) % args.validation_step == 0: val_loss = do_validate(model, valid_loader, criterion, scorers, vocab, teacher_p, args, params, stats, epoch, sc_activated, gts_sc_valid) if args.lr_scheduler == 'ReduceLROnPlateau': scheduler.step(val_loss) elif args.lr_scheduler == 'StepLR': scheduler.step() all_stats[str(epoch + 1)] = stats save_stats(args, params, all_stats, writer=writer) if writer is not None: # Log model data to tensorboard log_model_data(params, model, epoch + 1, writer) if writer is not None: writer.close()
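# ---------------------------------------------------------------------------
# Illustrative sketches of three helpers referenced in the training loop above
# (cyclical_lr, get_teacher_prob, and the self-critical criterion). The real
# implementations live elsewhere in this code base and may differ in detail;
# everything below is an assumption-laden outline, not the original code.
import math


def cyclical_lr(step_size, min_lr=3e-4, max_lr=3e-3):
    # Assumed to produce a per-iteration factor for LambdaLR following the
    # standard triangular cyclical schedule; this presumes the optimizer's
    # base lr is 1.0 so the returned factor is the effective learning rate.
    def relative(it, step_size):
        cycle = math.floor(1 + it / (2 * step_size))
        x = abs(it / step_size - 2 * cycle + 1)
        return max(0.0, 1.0 - x)

    return lambda it: min_lr + (max_lr - min_lr) * relative(it, step_size)


def get_teacher_prob(k, iteration, beta=0.3):
    # Assumed to implement inverse-sigmoid scheduled sampling (Bengio et al.,
    # 2015): the probability of feeding the ground-truth token decays with the
    # global iteration; k sets the decay speed and beta is taken here to be a
    # lower bound on the probability.
    p = k / (k + math.exp(min(iteration / k, 500.0)))  # clamp to avoid overflow
    return max(p, beta)


def self_critical_loss(sampled_log_probs, sampled_reward, greedy_reward,
                       return_advantage=False):
    # Sketch of the plain 'sc' criterion: the reward of the greedily decoded
    # caption serves as a baseline, and the advantage weights the negative
    # log-probability of the sampled caption (padding mask omitted here).
    # The rewards are assumed to have been computed beforehand from the
    # CIDEr-style scorers and the ground-truth captions.
    # sampled_log_probs: (batch, seq_len); sampled_reward/greedy_reward: (batch,)
    advantage = (sampled_reward - greedy_reward).detach()
    loss = -(advantage.unsqueeze(1) * sampled_log_probs).sum(dim=1).mean()
    if return_advantage:
        return loss, advantage.mean().item()
    return loss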
def main(): parser = argparse.ArgumentParser( "DINO training CLI", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "-m", "--model", type=str, default="vit_tiny", choices=["vit_tiny", "vit_small", "vit_base"], ) parser.add_argument("-b", "--batch-size", type=int, default=32) parser.add_argument("-d", "--device", type=int, default=0) parser.add_argument("--gpu", action="store_true") parser.add_argument("-l", "--logging-freq", type=int, default=200) parser.add_argument("--momentum-teacher", type=float, default=0.9995) parser.add_argument("-c", "--n-crops", type=int, default=4) parser.add_argument("-e", "--n-epochs", type=int, default=100) parser.add_argument("-o", "--out-dim", type=int, default=1024) parser.add_argument("-t", "--tensorboard-dir", type=str, default="") parser.add_argument("--optimizer", type=str, default="AdamW") parser.add_argument("--clip-grad", type=float, default=2.0) parser.add_argument("--norm-last-layer", action="store_true") parser.add_argument("--batch-size-eval", type=int, default=64) parser.add_argument("--teacher-temp", type=float, default=0.04) parser.add_argument("--student-temp", type=float, default=0.1) parser.add_argument("--pretrained", action="store_true") parser.add_argument("-w", "--weight-decay", type=float, default=0.4) args = parser.parse_args() print(vars(args)) # Parameters models = { "vit_tiny": [vit_tiny, 192], "vit_small": [vit_small, 384], "vit_base": [vit_base, 768], } path_dataset_train = pathlib.Path("data/imagenette2-320/train") path_dataset_val = pathlib.Path("data/imagenette2-320/val") path_labels = pathlib.Path("data/imagenette_labels.json") if args.gpu: torch.cuda.empty_cache() torch.cuda.set_device(args.device) device = torch.cuda.current_device() print(f"Current CUDA device: {device}") else: device = torch.device("cpu") print(f"Current device: {device}") n_workers = 4 ################## # Data preparation ################## with path_labels.open("r") as f: label_mapping = json.load(f) transform_aug = DataAugmentation(size=224, n_local_crops=args.n_crops - 2) transform_plain = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), transforms.Resize((224, 224)), ]) dataset_train_aug = ImageFolder(path_dataset_train, transform=transform_aug) dataset_train_plain = ImageFolder(path_dataset_train, transform=transform_plain) dataset_val_plain = ImageFolder(path_dataset_val, transform=transform_plain) if dataset_train_plain.classes != dataset_val_plain.classes: raise ValueError("Inconsistent classes") train_dataloader_aug = DataLoader( dataset_train_aug, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=n_workers, pin_memory=True, ) train_dataloader_plain = DataLoader( dataset_train_plain, batch_size=args.batch_size_eval, drop_last=False, num_workers=n_workers, ) val_dataloader_plain = DataLoader( dataset_val_plain, batch_size=args.batch_size_eval, drop_last=False, num_workers=n_workers, ) val_dataloader_plain_subset = DataLoader( dataset_val_plain, batch_size=args.batch_size_eval, drop_last=False, sampler=SubsetRandomSampler(list(range(0, len(dataset_val_plain), 50))), num_workers=n_workers, ) print(f"[INFO] Data loaded") ######### # Logging ######### run = neptune.init(project="beomus/dino-test") run["config/parameters"] = json.dumps(vars(args)) writer = SummaryWriter(log_dir=args.tensorboard_dir) writer.add_text("arguments", json.dumps(vars(args))) logging_path = pathlib.Path(writer.log_dir) wandb.init(project="dino", entity="beomus") 
wandb.config.update(args) print(f"[INFO] Logging started") ####################### # Models initialization ####################### model_fn, dim = models[args.model] student_vit = model_fn() teacher_vit = model_fn() student = MultiCropWrapper( student_vit, MlpHead(in_dim=dim, out_dim=args.out_dim, norm_last_layer=args.norm_last_layer), ) teacher = MultiCropWrapper(teacher_vit, MlpHead(dim, args.out_dim)) student, teacher = student.to(device), teacher.to(device) teacher.load_state_dict(student.state_dict()) for p in teacher.parameters(): p.requires_grad = False print(f"[INFO]: Model initialized") ###### # Loss ###### loss_inst = Loss( out_dim=args.out_dim, teacher_temp=args.teacher_temp, student_temp=args.student_temp, ).to(device) lr = 0.0005 * args.batch_size / 256 optimizer_kwargs = { "params": student.parameters(), "lr": lr, "weight_decay": args.weight_decay, "amsgrad": True, } if args.optimizer == "SGD": optimizer_kwargs["momentum"] = 0.9 optimizer_kwargs.pop("amsgrad") optimizer = getattr(torch.optim, args.optimizer)(**optimizer_kwargs) # optimizer = torch.optim.AdamW( # student.parameters(), lr=lr, weight_decay=args.weight_decay # ) model_name = f"{type(student).__name__}" with open(f"{logging_path / model_name}_arch.txt", "w") as f: f.write(str(student)) run[f"config/model/{model_name}_arch"].upload( f"{logging_path / model_name}_arch.txt") optimizer_name = f"{type(optimizer).__name__}" with open(f"{logging_path / optimizer_name}.txt", "w") as f: f.write(str(optimizer)) run[f"config/{optimizer_name}"].upload( f"{logging_path / optimizer_name}.txt") ############### # Training loop ############### n_batches = len(dataset_train_aug) // args.batch_size n_steps, best_acc = 0, 0 print(f"[INFO]: Training started") for epoch in range(args.n_epochs): for i, (images, _) in tqdm.tqdm(enumerate(train_dataloader_aug), total=n_batches): if n_steps % args.logging_freq == 0: student.eval() # embedding embs, imgs, labels_ = compute_embedding( student.backbone, val_dataloader_plain_subset) writer.add_embedding( embs, metadata=[label_mapping[l] for l in labels_], label_img=imgs, global_step=n_steps, tag="embeddings", ) # KNN current_acc = compute_knn(student.backbone, train_dataloader_plain, val_dataloader_plain) writer.add_scalar("knn-accuracy", current_acc, n_steps) run["metrics/acc"].log(current_acc) wandb.log({"accuracy": current_acc}) if current_acc > best_acc: model_path = str(logging_path / "model_best.pth") torch.save(student, model_path) run["model_checkpoints/my_model"].upload(model_path) best_acc = current_acc student.train() images = [img.to(device) for img in images] teacher_output = teacher(images[:2]) student_output = student(images) loss = loss_inst(student_output, teacher_output) optimizer.zero_grad() loss.backward() clip_gradients(student, args.clip_grad) optimizer.step() with torch.no_grad(): for student_ps, teacher_ps in zip(student.parameters(), teacher.parameters()): teacher_ps.data.mul_(args.momentum_teacher) teacher_ps.data.add_( (1 - args.momentum_teacher) * student_ps.detach().data) writer.add_scalar("train_loss", loss, n_steps) run["metrics/loss"].log(loss) wandb.log({"loss": loss}) n_steps += 1 print(f"[INFO]: Training ended") run.stop()
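# ---------------------------------------------------------------------------
# Illustrative sketches of two helpers used in main() above (Loss and
# compute_knn). They are assumption-laden outlines, not the original
# implementations: the Loss sketch assumes MultiCropWrapper returns one
# (batch, out_dim) tensor per crop, and compute_knn's k=20 plus the use of
# scikit-learn are illustrative choices.
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.neighbors import KNeighborsClassifier


class Loss(nn.Module):
    # DINO-style objective: cross-entropy between the sharpened, centered
    # teacher distributions and the student distributions over all pairs of
    # different crops.
    def __init__(self, out_dim, teacher_temp=0.04, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output):
        teacher_probs = [
            F.softmax((t - self.center) / self.teacher_temp, dim=-1).detach()
            for t in teacher_output  # global crops only, no gradient
        ]
        student_logs = [
            F.log_softmax(s / self.student_temp, dim=-1) for s in student_output
        ]
        total, n_terms = 0.0, 0
        for t_ix, t in enumerate(teacher_probs):
            for s_ix, s in enumerate(student_logs):
                if t_ix == s_ix:
                    continue  # never compare a view with itself
                total += (-t * s).sum(dim=-1).mean()
                n_terms += 1
        self.update_center(teacher_output)
        return total / n_terms

    @torch.no_grad()
    def update_center(self, teacher_output):
        # Running mean of the teacher outputs, subtracted before the softmax
        # to help avoid collapse.
        batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)
        self.center = self.center * self.center_momentum \
            + batch_center * (1 - self.center_momentum)


@torch.no_grad()
def compute_knn(backbone, train_loader, val_loader, k=20):
    # Embed both splits with the frozen backbone and report the validation
    # accuracy of a k-nearest-neighbour classifier fit on the train embeddings.
    device = next(backbone.parameters()).device

    def embed(loader):
        feats, labels = [], []
        for imgs, targets in loader:
            feats.append(backbone(imgs.to(device)).cpu())
            labels.append(targets)
        return torch.cat(feats).numpy(), torch.cat(labels).numpy()

    x_train, y_train = embed(train_loader)
    x_val, y_val = embed(val_loader)
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(x_train, y_train)
    return knn.score(x_val, y_val)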