Example #1
def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch,
                    fp16_scaler, args):
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
    for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        # update weight decay and learning rate according to their schedule
        it = len(data_loader) * epoch + it  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        # move images to gpu
        images = [im.cuda(non_blocking=True) for im in images]
        # teacher and student forward passes + compute dino loss
        with torch.cuda.amp.autocast(fp16_scaler is not None):
            teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
            student_output = student(images)
            loss = dino_loss(student_output, teacher_output, epoch)

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()), force=True)
            sys.exit(1)

        # student update
        optimizer.zero_grad()
        param_norms = None
        if fp16_scaler is None:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                fp16_scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # logging
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
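Note: utils.clip_gradients is called above as clip_gradients(student, args.clip_grad), but its definition is not part of this snippet. A minimal sketch of a per-parameter L2 clipper consistent with that call (returning the pre-clipping norms that the caller stores in param_norms) could look like the following; the body is an assumption, not the repository's verbatim code:

def clip_gradients(model, clip):
    # Rescale each parameter's gradient so its individual L2 norm is at most `clip`;
    # collect the norms before clipping so the caller can log them.
    norms = []
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            norms.append(param_norm.item())
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                p.grad.data.mul_(clip_coef)
    return norms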
Example #2
def construct_graph(self, model_cls):
    # model
    self.model = model_cls()
    # loss and accuracy
    train_ops = _construct_loss_and_eval_ops_for_regression(self.model, self.train_data, is_training=True)
    self.train_loss, self.train_tr_metric, self.train_val_metric = train_ops[0], train_ops[1], train_ops[2]
    eval_ops = _construct_loss_and_eval_ops_for_regression(self.model, self.eval_data, is_training=False)
    self.eval_loss, self.eval_tr_metric, self.eval_val_metric, self.eval_tr_input, self.eval_tr_output, \
            self.eval_tr_func, self.eval_val_input, self.eval_val_output, self.eval_val_func, self.eval_val_preds, \
            self.eval_val_sigma = eval_ops

    # optimisation
    training_variables = tf.compat.v1.trainable_variables()
    training_gradients = tf.gradients(self.train_loss, training_variables)
    training_gradients = utils.clip_gradients(training_gradients, 0, 0)  # gradient clipping is not used
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.outer_lr)
    self.global_step = tf.compat.v1.train.get_or_create_global_step()
    self.train_op = optimizer.apply_gradients(
        list(zip(training_gradients, training_variables)), self.global_step)
Example #3
def construct_graph(self, model_cls):
    # construct model
    self.model = model_cls()
    # construct loss and accuracy ops
    self.train_loss, self.train_tr_metric, self.train_val_metric = \
            _construct_loss_and_eval_ops_for_classification(self.model, self.train_data, is_training=True)
    self.train_eval_loss, self.train_eval_tr_metric, self.train_eval_val_metric = \
            _construct_loss_and_eval_ops_for_classification(self.model, self.train_data, is_training=False)
    self.eval_loss, self.eval_tr_metric, self.eval_val_metric = \
            _construct_loss_and_eval_ops_for_classification(self.model, self.eval_data, is_training=False)
    self.test_loss, self.test_tr_metric, self.test_val_metric = \
            _construct_loss_and_eval_ops_for_classification(self.model, self.test_data, is_training=False)
    # construct optimisation ops
    training_variables = tf.compat.v1.trainable_variables()
    training_gradients = tf.gradients(self.train_loss, training_variables)
    training_gradients = utils.clip_gradients(training_gradients, 0, 0)  # gradient clipping is not used
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.outer_lr)
    self.global_step = tf.compat.v1.train.get_or_create_global_step()
    self.train_op = optimizer.apply_gradients(
            list(zip(training_gradients, training_variables)), self.global_step)
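In Examples #2 and #3, the gradients are passed through utils.clip_gradients(training_gradients, 0, 0) with both thresholds set to zero, which (as the comment in Example #3 notes) disables clipping. A plausible sketch of such a two-threshold TensorFlow helper, assuming the arguments are a (value, global-norm) pair where 0 means "skip this stage":

import tensorflow as tf

def clip_gradients(gradients, value_threshold, norm_threshold):
    # Clip element-wise by value, then by global norm; a zero threshold
    # leaves the corresponding stage disabled.
    if value_threshold > 0:
        gradients = [tf.clip_by_value(g, -value_threshold, value_threshold)
                     for g in gradients]
    if norm_threshold > 0:
        gradients, _ = tf.clip_by_global_norm(gradients, norm_threshold)
    return gradients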
Example #4
def train_joint(
    data_loader,
    encoder,
    decoder,
    encoder_optimizer,
    decoder_optimizer,
    epoch,
    grad_clip,
    print_freq,
    gradnorm_optimizer,
    loss_weight_generation,
    loss_weight_ranking,
    gradnorm_loss,
    gradnorm_alpha,
    initial_generation_loss,
    initial_ranking_loss,
):
    """
    Perform one training epoch.

    """

    loss_weights = [loss_weight_generation, loss_weight_ranking]

    decoder.train()
    if encoder:
        encoder.train()

    losses = AverageMeter()

    # Loop over training batches
    for i, (images, target_captions,
            caption_lengths) in enumerate(data_loader):
        target_captions = target_captions.to(device)
        caption_lengths = caption_lengths.to(device)
        images = images.to(device)

        # Forward propagation
        if encoder:
            images = encoder(images)
        decode_lengths = caption_lengths.squeeze(1) - 1

        scores, decode_lengths, images_embedded, captions_embedded, alphas = decoder.forward_joint(
            images, target_captions, decode_lengths)
        loss_generation = decoder.loss(scores, target_captions, decode_lengths,
                                       alphas)
        loss_ranking = decoder.loss_ranking(images_embedded, captions_embedded)
        loss = (loss_weights[0] * loss_generation +
                loss_weights[1] * loss_ranking)

        decoder_optimizer.zero_grad()
        if encoder_optimizer:
            encoder_optimizer.zero_grad()
        loss.backward(retain_graph=True)

        # Get the gradients of the shared layers and calculate their l2-norm
        named_params = dict(decoder.named_parameters())
        shared_params = [
            param for param_name, param in named_params.items()
            if param_name in decoder.SHARED_PARAMS and param.requires_grad
        ]
        G1R = torch.autograd.grad(loss_generation,
                                  shared_params,
                                  retain_graph=True,
                                  create_graph=True)
        G1R_flattened = torch.cat([g.view(-1) for g in G1R])
        G1 = torch.norm(loss_weights[0] * G1R_flattened.data, 2).unsqueeze(0)
        G2R = torch.autograd.grad(loss_ranking, shared_params)
        G2R_flattened = torch.cat([g.view(-1) for g in G2R])
        G2 = torch.norm(loss_weights[1] * G2R_flattened.data, 2).unsqueeze(0)

        # Calculate the average gradient norm across all tasks
        G_avg = torch.div(torch.add(G1, G2), 2)

        # Calculate relative losses
        lhat1 = torch.div(loss_generation, initial_generation_loss)
        lhat2 = torch.div(loss_ranking, initial_ranking_loss)
        lhat_avg = torch.div(torch.add(lhat1, lhat2), 2)

        # Calculate relative inverse training rates
        inv_rate1 = torch.div(lhat1, lhat_avg)
        inv_rate2 = torch.div(lhat2, lhat_avg)

        # Calculate the gradient norm target for this batch
        C1 = G_avg * (inv_rate1**gradnorm_alpha)
        C2 = G_avg * (inv_rate2**gradnorm_alpha)

        # Calculate the gradnorm loss
        Lgrad = torch.add(gradnorm_loss(G1, C1.data),
                          gradnorm_loss(G2, C2.data))

        # Backprop and perform an optimization step
        gradnorm_optimizer.zero_grad()
        Lgrad.backward()
        gradnorm_optimizer.step()

        # Clip gradients
        if grad_clip:
            clip_gradients(decoder_optimizer, grad_clip)
            if encoder_optimizer:
                clip_gradients(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer:
            encoder_optimizer.step()

        # Keep track of metrics
        losses.update(loss.item(), sum(decode_lengths).item())

        # Log status
        if i % print_freq == 0:
            logging.info(
                "Epoch: {0} [Batch {1}/{2}]\t"
                "Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t Loss weights: Generation: {3:.4f} Ranking: {4:.4f}"
                .format(
                    epoch,
                    i,
                    len(data_loader),
                    loss_weights[0].item(),
                    loss_weights[1].item(),
                    loss=losses,
                ))

        # Renormalize the gradnorm weights
        coef = 2 / torch.add(loss_weight_generation, loss_weight_ranking)
        loss_weights = [
            coef * loss_weight_generation, coef * loss_weight_ranking
        ]

    logging.info("\n * LOSS - {loss.avg:.3f}\n".format(loss=losses))
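Unlike Example #1, train_joint (and the plain train in Example #6 below) clips through the optimizer rather than the model: clip_gradients(decoder_optimizer, grad_clip). A minimal sketch consistent with that signature, assuming element-wise value clipping over the optimizer's parameter groups:

def clip_gradients(optimizer, grad_clip):
    # Clamp every gradient reachable through the optimizer's param groups
    # to the interval [-grad_clip, grad_clip].
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)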
Example #5
def train(model, train_loader, val_loader, logger, optim, output, gpu=True, **train_config):
    # the optimizer and learning rate are defined by the caller
    # torch.autograd.set_detect_anomaly(True)
    accuracy = VQAAccuracy()          # accuracy metric

#     class_frequence = torch.zeros_like(train_loader.dataset[0]['answer'][:-50])
#     for sample in train_loader.dataset:
#         class_frequence += torch.ceil(sample['answer'][:-50])
#     class_weight = torch.exp(-class_frequence/class_frequence.sum())
#     print(class_weight)

    lbce = LogitBinaryCrossEntropy()

    # set up the learning-rate schedule
    scheduler_func = lambda x: lr_lambda_update(x, **train_config)
    lr_scheduler = LambdaLR(optim, lr_lambda=scheduler_func)
    
#     lr_scheduler = ExponentialLR(optim, 0.5**(1/50000))
    
    iteration = 0
    best_val_accuracy = 0
    best_epoch = 0
    patience = 0
    saving_epoch = 4
    
    log_train_accuracy = {}
    log_train_loss = {}
    if val_loader is not None:
        log_val_loss = {}
        log_val_accuracy = {}
    
    for epoch in range(1, train_config["epoch"] + 1):
#     while iteration < train_config["max_iterations"]:
        model.train()
        print("Epoch %d of Train:" % epoch)
        total_accuracy = 0
        total_loss = 0
        for i, sample in enumerate(train_loader):
            iteration += 1
            input_ids = Variable(sample["input_ids"])
            token_type_ids = Variable(sample["token_type_ids"])
            attention_mask = Variable(sample["attention_mask"])
            
            img = Variable(sample["img_feature"])
#             context = Variable(torch.zeros_like(sample["context_feature"]))
            context = Variable(sample["context_feature"])
            labels = Variable(sample['answer'])
            bbox = Variable(sample['bbox'])
            ocrbbox = Variable(sample['ocrbbox'])
            
            if gpu:
                input_ids = input_ids.cuda()   
                token_type_ids = token_type_ids.cuda()
                attention_mask = attention_mask.cuda()
                img = img.cuda()
                context = context.cuda()
                labels = labels.cuda()
                bbox = bbox.cuda()
                ocrbbox = ocrbbox.cuda()
            
            batch_size = img.size(0)
            prediction = model(img, bbox, input_ids, token_type_ids, attention_mask, context, ocrbbox)
            
            loss, lc, lo = lbce(labels, prediction)
            if epoch <= 13:
                loss.backward()
            elif epoch % 2 == 0:
                lc.backward()
            else:
                lo.backward()
#             if (i+1)%2==0:
            lr_scheduler.step(epoch)
            clip_gradients(model, train_config["max_grad_l2_norm"], train_config["clip_norm_mode"])
            optim.step()
            optim.zero_grad()
            # accumulate accuracy
            batch_accuracy = accuracy(labels.data, prediction.data)
            total_accuracy += batch_accuracy * batch_size
            # accumulate loss
            total_loss += loss.data
            if (i + 1) % 10 == 0:
                log_msg = "[%d/%d/%d] iter:%d accuracy:%2.2f loss:%.5f lr: %f"%(
                    epoch, 
                    len(train_loader),
                    i,
                    iteration, 
                    batch_accuracy*100,
                    loss.data,
                    optim.param_groups[0]['lr']
                )
                print(log_msg)
        if val_loader is not None:
            print("Epoch %d of Validation:" % epoch)
            val_accuracy, val_loss = evaluate(model, val_loader, logger, gpu, **train_config)
        print("Result")
        
        log_msg = "Train accuracy:%2.2f, Train loss: %.5f"%(
            total_accuracy/len(train_loader.dataset)*100, 
            total_loss/len(train_loader.dataset)
        )
        if val_loader is not None:
            log_msg += ", Val accuracy:%2.2f, Val loss: %.5f" % (val_accuracy*100, val_loss)
        print(log_msg)
        
        log_train_accuracy["epoch "+str(epoch)] = total_accuracy.cpu().numpy().tolist()/len(train_loader.dataset)*100
        log_train_loss["epoch "+str(epoch)] = total_loss.cpu().numpy().tolist()/len(train_loader.dataset)
        
        if val_loader is not None:
            log_val_accuracy["epoch "+str(epoch)] = val_accuracy.cpu().numpy().tolist()*100
            log_val_loss["epoch "+str(epoch)] = val_loss.cpu().numpy().tolist()
        
        if (val_loader is not None and val_accuracy > best_val_accuracy) or (val_loader is None and epoch >= saving_epoch):
            model_name = 'model_%s.pth' % ('epoch' + str(epoch) if val_loader is None else 'best')
            model_path = os.path.join(output, model_name)
            utils.save_model(model_path, model, epoch, optim)
            if val_loader is not None:
                best_val_accuracy = val_accuracy
                best_epoch = epoch
                patience = 0
        elif val_loader is not None:
            patience += 1
            if patience >= 15:
                print("Patience %d, early stopping!" % patience)
                break
        print("Patience %d" % patience)
    
    log_msg = "Best val accuracy: %2.2f at epoch %d." % (best_val_accuracy * 100, best_epoch)
    print(log_msg)
    logger.add("train accuracy", log_train_accuracy)  
    logger.add("train loss", log_train_loss)
    if val_loader is not None:
        logger.add("best val accuracy", best_val_accuracy.cpu().numpy().tolist())
        logger.add("best_epoch", best_epoch)
        logger.add("val loss", log_val_loss)
        logger.add("val accuracy", log_val_accuracy)  
    logger.save_log()
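Here clip_gradients receives the model plus two config values, max_grad_l2_norm and clip_norm_mode, which suggests a norm-based clipper with a mode switch. One plausible shape, assuming a mode named "all" that clips the combined L2 norm of all parameters (both the mode name and the fallback behavior are assumptions):

import torch.nn as nn

def clip_gradients(model, max_grad_l2_norm, clip_norm_mode):
    # Clip the global L2 norm of all gradients; other modes are not sketched.
    if max_grad_l2_norm is None:
        return
    if clip_norm_mode == "all":
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_l2_norm)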
Example #6
def train(
    model_name,
    data_loader,
    encoder,
    decoder,
    encoder_optimizer,
    decoder_optimizer,
    epoch,
    grad_clip,
    print_freq,
):
    """
    Perform one training epoch.

    """

    decoder.train()
    if encoder:
        encoder.train()

    losses = AverageMeter()

    # Loop over training batches
    for i, (images, target_captions,
            caption_lengths) in enumerate(data_loader):
        target_captions = target_captions.to(device)
        caption_lengths = caption_lengths.to(device)
        images = images.to(device)

        # Forward propagation
        if encoder:
            images = encoder(images)
        decode_lengths = caption_lengths.squeeze(1) - 1

        if model_name == MODEL_BOTTOM_UP_TOP_DOWN_RANKING:
            scores, decode_lengths, images_embedded, captions_embedded, alphas = decoder.forward_joint(
                images, target_captions, decode_lengths)
            loss = decoder.loss(scores, target_captions, decode_lengths,
                                alphas)

        else:
            scores, decode_lengths, alphas = decoder(images, target_captions,
                                                     decode_lengths)
            loss = decoder.loss(scores, target_captions, decode_lengths,
                                alphas)

        decoder_optimizer.zero_grad()
        if encoder_optimizer:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip:
            clip_gradients(decoder_optimizer, grad_clip)
            if encoder_optimizer:
                clip_gradients(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer:
            encoder_optimizer.step()

        # Keep track of metrics
        losses.update(loss.item(), sum(decode_lengths).item())

        # Log status
        if i % print_freq == 0:
            logging.info(
                "Epoch: {0} [Batch {1}/{2}]\t"
                "Loss: {loss.val:.4f} (Average: {loss.avg:.4f})\t".format(
                    epoch, i, len(data_loader), loss=losses))

    logging.info("\n * LOSS - {loss.avg:.3f}\n".format(loss=losses))
Example #7
def main(argv):
    del argv  # unused
    np.random.seed(FLAGS.seed)
    #tf.compat.v1.set_random_seed(FLAGS.seed)

    print("Testing: ", FLAGS.testing, f"\t Seed: {FLAGS.seed}")

    FLAGS.encoder_sizes = [int(size) for size in FLAGS.encoder_sizes]
    FLAGS.decoder_sizes = [int(size) for size in FLAGS.decoder_sizes]

    if 0 in FLAGS.encoder_sizes:
        FLAGS.encoder_sizes.remove(0)
    if 0 in FLAGS.decoder_sizes:
        FLAGS.decoder_sizes.remove(0)

    # Make up full exp name
    timestamp = datetime.now().strftime("%y%m%d")
    full_exp_name = "{}_{}".format(timestamp, FLAGS.exp_name)
    outdir = os.path.join(FLAGS.basedir, full_exp_name)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    checkpoint_prefix = os.path.join(outdir, "ckpt")
    print("Full exp name: ", full_exp_name)

    ###################################
    # Define data specific parameters #
    ###################################

    if FLAGS.data_type == "hmnist":
        FLAGS.data_dir = "data/hmnist/hmnist_mnar.npz"
        data_dim = 784
        time_length = 10
        num_classes = 10
        decoder = BernoulliDecoder
        img_shape = (28, 28, 1)
        val_split = 50000
    elif FLAGS.data_type == "physionet":
        if FLAGS.data_dir == "":
            FLAGS.data_dir = "data/physionet/physionet.npz"
        data_dim = 35
        time_length = 48
        num_classes = 2

        decoder = GaussianDecoder
    elif FLAGS.data_type == "sprites":
        if FLAGS.data_dir == "":
            FLAGS.data_dir = "data/sprites/sprites.npz"
        data_dim = 12288
        time_length = 8
        decoder = GaussianDecoder
        img_shape = (64, 64, 3)
        val_split = 8000
    else:
        raise ValueError(
            "Data type must be one of ['hmnist', 'physionet', 'sprites']")

    #############
    # Load data #
    #############

    data = np.load(FLAGS.data_dir)
    x_train_full = data['x_train_full']
    x_train_miss = data['x_train_miss']
    m_train_miss = data['m_train_miss']
    if FLAGS.data_type in ['hmnist', 'physionet']:
        y_train = data['y_train']

    if FLAGS.testing:
        if FLAGS.data_type in ['hmnist', 'sprites']:
            x_val_full = data['x_test_full']
            x_val_miss = data['x_test_miss']
            m_val_miss = data['m_test_miss']
        if FLAGS.data_type == 'hmnist':
            y_val = data['y_test']
        elif FLAGS.data_type == 'physionet':
            x_val_full = data['x_train_full']
            x_val_miss = data['x_train_miss']
            m_val_miss = data['m_train_miss']
            y_val = data['y_train']
            m_val_artificial = data["m_train_artificial"]
    elif FLAGS.data_type in ['hmnist', 'sprites']:
        x_val_full = x_train_full[val_split:]
        x_val_miss = x_train_miss[val_split:]
        m_val_miss = m_train_miss[val_split:]
        if FLAGS.data_type == 'hmnist':
            y_val = y_train[val_split:]
        x_train_full = x_train_full[:val_split]
        x_train_miss = x_train_miss[:val_split]
        m_train_miss = m_train_miss[:val_split]
        y_train = y_train[:val_split]
    elif FLAGS.data_type == 'physionet':
        x_val_full = data["x_val_full"]  # full for artificial missings
        x_val_miss = data["x_val_miss"]
        m_val_miss = data["m_val_miss"]
        m_val_artificial = data["m_val_artificial"]
        y_val = data["y_val"]
    else:
        raise ValueError(
            "Data type must be one of ['hmnist', 'physionet', 'sprites']")

    tf_x_train_miss = DataLoader(MyDataset(x_train_miss, m_train_miss),
                                 shuffle=True,
                                 batch_size=FLAGS.batch_size)
    tf_x_val_miss = iter(
        DataLoader(MyDataset(x_val_miss, m_val_miss),
                   batch_size=FLAGS.batch_size,
                   shuffle=False))
    tf_x_test_miss = DataLoader(MyDataset(x_val_miss, m_val_miss),
                                batch_size=len(x_val_miss),
                                shuffle=False)

    # Build Conv2D preprocessor for image data
    if FLAGS.data_type in ['hmnist', 'sprites']:
        print("Using CNN preprocessor")
        image_preprocessor = ImagePreprocessor(img_shape, FLAGS.cnn_sizes,
                                               FLAGS.cnn_kernel_size)
    elif FLAGS.data_type == 'physionet':
        image_preprocessor = None
    else:
        raise ValueError(
            "Data type must be one of ['hmnist', 'physionet', 'sprites']")

    ###############
    # Build model #
    ###############

    if FLAGS.model_type == "vae":
        model = VAE(latent_dim=FLAGS.latent_dim,
                    data_dim=data_dim,
                    time_length=time_length,
                    encoder_sizes=FLAGS.encoder_sizes,
                    encoder=DiagonalEncoder,
                    decoder_sizes=FLAGS.decoder_sizes,
                    decoder=decoder,
                    image_preprocessor=image_preprocessor,
                    window_size=FLAGS.window_size,
                    beta=FLAGS.beta,
                    M=FLAGS.M,
                    K=FLAGS.K)
    elif FLAGS.model_type == "hi-vae":
        model = HI_VAE(latent_dim=FLAGS.latent_dim,
                       data_dim=data_dim,
                       time_length=time_length,
                       encoder_sizes=FLAGS.encoder_sizes,
                       encoder=DiagonalEncoder,
                       decoder_sizes=FLAGS.decoder_sizes,
                       decoder=decoder,
                       image_preprocessor=image_preprocessor,
                       window_size=FLAGS.window_size,
                       beta=FLAGS.beta,
                       M=FLAGS.M,
                       K=FLAGS.K)
    elif FLAGS.model_type == "gp-vae":
        encoder = BandedJointEncoder if FLAGS.banded_covar else JointEncoder
        model = GP_VAE(latent_dim=FLAGS.latent_dim,
                       data_dim=data_dim,
                       time_length=time_length,
                       encoder_sizes=FLAGS.encoder_sizes,
                       encoder=encoder,
                       decoder_sizes=FLAGS.decoder_sizes,
                       decoder=decoder,
                       kernel=FLAGS.kernel,
                       sigma=FLAGS.sigma,
                       length_scale=FLAGS.length_scale,
                       kernel_scales=FLAGS.kernel_scales,
                       image_preprocessor=image_preprocessor,
                       window_size=FLAGS.window_size,
                       beta=FLAGS.beta,
                       M=FLAGS.M,
                       K=FLAGS.K)
    else:
        raise ValueError(
            "Model type must be one of ['vae', 'hi-vae', 'gp-vae']")

    clip_gradients(model, FLAGS.gradient_clip)
    ########################
    # Training preparation #
    ########################

    print("GPU support: ", tf.test.is_gpu_available())

    print("Training...")
    #_ = tf.compat.v1.train.get_or_create_global_step()
    #trainable_vars = model.get_trainable_vars()
    #optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print("Encoder: ", model.encoder)
    print("Decoder: ", model.decoder)

    if model.preprocessor is not None:
        print("Preprocessor: ", model.preprocessor.net.summary())
        #saver = tf.compat.v1.train.Checkpoint(optimizer=optimizer, encoder=model.encoder.net,
        #                                      decoder=model.decoder.net, preprocessor=model.preprocessor.net,
        #                                      optimizer_step=tf.compat.v1.train.get_or_create_global_step())
    else:
        #saver = tf.compat.v1.train.Checkpoint(optimizer=optimizer, encoder=model.encoder.net, decoder=model.decoder.net,
        #                                      optimizer_step=tf.compat.v1.train.get_or_create_global_step())
        pass
    #summary_writer = tf.contrib.summary.create_file_writer(outdir, flush_millis=10000)

    if FLAGS.num_steps == 0:
        num_steps = FLAGS.num_epochs * len(x_train_miss) // FLAGS.batch_size
    else:
        num_steps = FLAGS.num_steps

    if FLAGS.print_interval == 0:
        FLAGS.print_interval = num_steps // FLAGS.num_epochs

    ############
    # Training #
    ############

    losses_train = []
    losses_val = []

    t0 = time.time()
    #with summary_writer.as_default(), tf.contrib.summary.always_record_summaries():
    for i, (x_seq, m_seq) in enumerate(tf_x_train_miss):
        if i >= num_steps:
            break
        try:
            #print(x_seq.shape)
            optimizer.zero_grad()
            #with tf.GradientTape() as tape:
            #    tape.watch(trainable_vars)

            loss = model.compute_loss(x_seq, m_mask=m_seq)
            losses_train.append(loss.detach().numpy())
            loss.backward()  # compute gradients before the optimizer step
            #grads = loss.grad
            #print(grads)
            #grads = tape.gradient(loss, trainable_vars)
            #grads = [np.nan_to_num(grad) for grad in grads]
            # torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            #grads, global_norm = tf.clip_by_global_norm(grads, FLAGS.gradient_clip)
            optimizer.step()
            #optimizer.apply_gradients(zip(grads, trainable_vars),
            #                          global_step=tf.compat.v1.train.get_or_create_global_step())

            # Print intermediate results
            if i % FLAGS.print_interval == 0:
                print("================================================")
                print("Learning rate: {}".format(
                    optimizer.param_groups[0]['lr']))
                print("Step {}) Time = {:2f}".format(i, time.time() - t0))
                loss, nll, kl = model.compute_loss(x_seq,
                                                   m_mask=m_seq,
                                                   return_parts=True)
                print(
                    "Train loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}".format(
                        loss, nll, kl))
                torch.save(model, checkpoint_prefix)

                #saver.save(checkpoint_prefix)
                print("loss_train", loss)
                print("kl_train", kl)
                print("nll_train", nll)
                #tf.contrib.summary.scalar("nll_train", nll)

                # Validation loss
                x_val_batch, m_val_batch = next(tf_x_val_miss)
                val_loss, val_nll, val_kl = model.compute_loss(
                    x_val_batch, m_mask=m_val_batch, return_parts=True)
                losses_val.append(val_loss.detach().numpy())
                print("Validation loss = {:.3f} | NLL = {:.3f} | KL = {:.3f}".
                      format(val_loss, val_nll, val_kl))

                print("loss_val", val_loss)
                print("kl_val", val_kl)
                print("nll_val", val_nll)

                if FLAGS.data_type in ["hmnist", "sprites"]:
                    # Draw reconstructed images
                    x_hat = model.decode(model.encode(x_seq).sample()).mean
                    tf.contrib.summary.image(
                        "input_train", x_seq.reshape([-1] + list(img_shape)))
                    tf.contrib.summary.image(
                        "reconstruction_train",
                        x_hat.reshape([-1] + list(img_shape)))
                elif FLAGS.data_type == 'physionet':
                    # Eval MSE and AUROC on entire val set
                    x_val_miss_batches = np.array_split(x_val_miss,
                                                        FLAGS.batch_size,
                                                        axis=0)
                    x_val_full_batches = np.array_split(x_val_full,
                                                        FLAGS.batch_size,
                                                        axis=0)
                    m_val_artificial_batches = np.array_split(m_val_artificial,
                                                              FLAGS.batch_size,
                                                              axis=0)
                    get_val_batches = lambda: zip(x_val_miss_batches,
                                                  x_val_full_batches,
                                                  m_val_artificial_batches)

                    n_missings = m_val_artificial.sum()
                    mse_miss = np.sum([
                        model.compute_mse(x, y=y, m_mask=m).item()
                        for x, y, m in get_val_batches()
                    ]) / n_missings

                    x_val_imputed = np.vstack([
                        model.decode(
                            model.encode(x_batch).mean).mean.detach().numpy()
                        for x_batch in x_val_miss_batches
                    ])
                    x_val_imputed[m_val_miss == 0] = x_val_miss[
                        m_val_miss == 0]  # impute gt observed values

                    x_val_imputed = x_val_imputed.reshape(
                        [-1, time_length * data_dim])
                    val_split = len(x_val_imputed) // 2
                    cls_model = LogisticRegression(solver='liblinear',
                                                   tol=1e-10,
                                                   max_iter=10000)
                    cls_model.fit(x_val_imputed[:val_split], y_val[:val_split])
                    probs = cls_model.predict_proba(
                        x_val_imputed[val_split:])[:, 1]
                    auroc = roc_auc_score(y_val[val_split:], probs)
                    print("MSE miss: {:.4f} | AUROC: {:.4f}".format(
                        mse_miss, auroc))

                    # Update learning rate (used only for physionet with decay=0.5)
                    if i > 0 and i % (10 * FLAGS.print_interval) == 0:
                        for group in optimizer.param_groups:
                            group['lr'] = max(0.5 * group['lr'],
                                              0.1 * FLAGS.learning_rate)
                t0 = time.time()
        except KeyboardInterrupt:
            torch.save(model, checkpoint_prefix)
            if FLAGS.debug:
                import ipdb
                ipdb.set_trace()
            break

    ##############
    # Evaluation #
    ##############

    print("Evaluation...")

    # Split data on batches
    x_val_miss_batches = np.array_split(x_val_miss, FLAGS.batch_size, axis=0)
    x_val_full_batches = np.array_split(x_val_full, FLAGS.batch_size, axis=0)
    if FLAGS.data_type == 'physionet':
        m_val_batches = np.array_split(m_val_artificial,
                                       FLAGS.batch_size,
                                       axis=0)
    else:
        m_val_batches = np.array_split(m_val_miss, FLAGS.batch_size, axis=0)
    get_val_batches = lambda: zip(x_val_miss_batches, x_val_full_batches,
                                  m_val_batches)

    # Compute NLL and MSE on missing values
    n_missings = (m_val_artificial.sum() if FLAGS.data_type == 'physionet'
                  else m_val_miss.sum())
    nll_miss = np.sum([
        model.compute_nll(x, y=y, m_mask=m).item()  #.detach().numpy()
        for x, y, m in get_val_batches()
    ]) / n_missings
    mse_miss = np.sum([
        model.compute_mse(x, y=y, m_mask=m, binary=FLAGS.data_type
                          == "hmnist").item()  #.detach().numpy()
        for x, y, m in get_val_batches()
    ]) / n_missings
    print("NLL miss: {:.4f}".format(nll_miss))
    print("MSE miss: {:.4f}".format(mse_miss))

    # Save imputed values
    z_mean = [
        model.encode(x_batch).mean.detach().numpy()
        for x_batch in x_val_miss_batches
    ]
    np.save(os.path.join(outdir, "z_mean"), np.vstack(z_mean))
    x_val_imputed = np.vstack(
        [model.decode(z_batch).mean.detach().numpy() for z_batch in z_mean])
    np.save(os.path.join(outdir, "imputed_no_gt"), x_val_imputed)

    # impute gt observed values
    x_val_imputed[m_val_miss == 0] = x_val_miss[m_val_miss == 0]
    np.save(os.path.join(outdir, "imputed"), x_val_imputed)

    if FLAGS.data_type == "hmnist":
        # AUROC evaluation using Logistic Regression
        x_val_imputed = np.round(x_val_imputed)
        x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim])

        cls_model = LogisticRegression(solver='lbfgs',
                                       multi_class='multinomial',
                                       tol=1e-10,
                                       max_iter=10000)
        val_split = len(x_val_imputed) // 2

        cls_model.fit(x_val_imputed[:val_split], y_val[:val_split])
        probs = cls_model.predict_proba(x_val_imputed[val_split:])

        auprc = average_precision_score(
            np.eye(num_classes)[y_val[val_split:]], probs)
        auroc = roc_auc_score(np.eye(num_classes)[y_val[val_split:]], probs)
        print("AUROC: {:.4f}".format(auroc))
        print("AUPRC: {:.4f}".format(auprc))

    elif FLAGS.data_type == "sprites":
        auroc, auprc = 0, 0

    elif FLAGS.data_type == "physionet":
        # Uncomment to preserve some z_samples and their reconstructions
        # for i in range(5):
        #     z_sample = [model.encode(x_batch).sample().numpy() for x_batch in x_val_miss_batches]
        #     np.save(os.path.join(outdir, "z_sample_{}".format(i)), np.vstack(z_sample))
        #     x_val_imputed_sample = np.vstack([model.decode(z_batch).mean().numpy() for z_batch in z_sample])
        #     np.save(os.path.join(outdir, "imputed_sample_{}_no_gt".format(i)), x_val_imputed_sample)
        #     x_val_imputed_sample[m_val_miss == 0] = x_val_miss[m_val_miss == 0]
        #     np.save(os.path.join(outdir, "imputed_sample_{}".format(i)), x_val_imputed_sample)

        # AUROC evaluation using Logistic Regression
        x_val_imputed = x_val_imputed.reshape([-1, time_length * data_dim])
        val_split = len(x_val_imputed) // 2
        cls_model = LogisticRegression(solver='liblinear',
                                       tol=1e-10,
                                       max_iter=10000)
        cls_model.fit(x_val_imputed[:val_split], y_val[:val_split])
        probs = cls_model.predict_proba(x_val_imputed[val_split:])[:, 1]
        auprc = average_precision_score(y_val[val_split:], probs)
        auroc = roc_auc_score(y_val[val_split:], probs)

        print("AUROC: {:.4f}".format(auroc))
        print("AUPRC: {:.4f}".format(auprc))

    # Visualize reconstructions
    if FLAGS.data_type in ["hmnist", "sprites"]:
        img_index = 0
        if FLAGS.data_type == "hmnist":
            img_shape = (28, 28)
            cmap = "gray"
        elif FLAGS.data_type == "sprites":
            img_shape = (64, 64, 3)
            cmap = None

        fig, axes = plt.subplots(nrows=3,
                                 ncols=x_val_miss.shape[1],
                                 figsize=(2 * x_val_miss.shape[1], 6))

        x_hat = model.decode(
            model.encode(x_val_miss[img_index:img_index +
                                    1]).mean).mean.detach().numpy()
        seqs = [
            x_val_miss[img_index:img_index + 1], x_hat,
            x_val_full[img_index:img_index + 1]
        ]

        for axs, seq in zip(axes, seqs):
            for ax, img in zip(axs, seq[0]):
                ax.imshow(img.reshape(img_shape), cmap=cmap)
                ax.axis('off')

        suptitle = FLAGS.model_type + f" reconstruction, MSE missing = {mse_miss}"
        fig.suptitle(suptitle, size=18)
        fig.savefig(
            os.path.join(outdir, FLAGS.data_type + "_reconstruction.pdf"))

    results_all = [
        FLAGS.seed, FLAGS.model_type, FLAGS.data_type, FLAGS.kernel,
        FLAGS.beta, FLAGS.latent_dim, FLAGS.num_epochs, FLAGS.batch_size,
        FLAGS.learning_rate, FLAGS.window_size, FLAGS.kernel_scales,
        FLAGS.sigma, FLAGS.length_scale,
        len(FLAGS.encoder_sizes),
        FLAGS.encoder_sizes[0] if len(FLAGS.encoder_sizes) > 0 else 0,
        len(FLAGS.decoder_sizes),
        FLAGS.decoder_sizes[0] if len(FLAGS.decoder_sizes) > 0 else 0,
        FLAGS.cnn_kernel_size, FLAGS.cnn_sizes, nll_miss, mse_miss,
        losses_train[-1], losses_val[-1], auprc, auroc, FLAGS.testing,
        FLAGS.data_dir
    ]

    with open(os.path.join(outdir, "results.tsv"), "w") as outfile:
        outfile.write(
            "seed\tmodel\tdata\tkernel\tbeta\tz_size\tnum_epochs"
            "\tbatch_size\tlearning_rate\twindow_size\tkernel_scales\t"
            "sigma\tlength_scale\tencoder_depth\tencoder_width\t"
            "decoder_depth\tdecoder_width\tcnn_kernel_size\t"
            "cnn_sizes\tNLL\tMSE\tlast_train_loss\tlast_val_loss\tAUPRC\tAUROC\ttesting\tdata_dir\n"
        )
        outfile.write("\t".join(map(str, results_all)))

    with open(os.path.join(outdir, "training_curve.tsv"), "w") as outfile:
        outfile.write("\t".join(map(str, losses_train)))
        outfile.write("\n")
        outfile.write("\t".join(map(str, losses_val)))

    print("Training finished.")
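In this example clip_gradients(model, FLAGS.gradient_clip) is called once, before the training loop, so it cannot be clipping any particular backward pass. One way a single up-front call can have a lasting effect is by registering backward hooks; a sketch under that assumption:

def clip_gradients(model, clip_value):
    # Register a hook on each parameter so its gradient is clamped to
    # [-clip_value, clip_value] as it is computed on every backward pass.
    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(lambda grad: grad.clamp(-clip_value, clip_value))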
Example #8
def main():
    is_training = tf.placeholder(tf.bool, name='is_training')
    num_class = args.num_way
    num_shot = args.num_shot
    num_query = args.num_query

    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    support_label = tf.placeholder(tf.int32, (None, ), 'support_label')
    query_label = tf.placeholder(tf.int32, (None, ), 'query_label')

    support_x = tf.placeholder(tf.float32, (None, 640), 'support_x')
    query_x = tf.placeholder(tf.float32, (None, 640), 'query_x')

    support_feature = support_x
    query_feature = query_x
    support_feature = tf.reshape(support_feature,
                                 (batch_size, num_class, num_shot, 640))
    query_feature = tf.reshape(query_feature,
                               (batch_size, num_class, num_query, 640))
    support_label_reshape = tf.reshape(support_label,
                                       (batch_size, num_class, num_shot))
    query_label_reshape = tf.reshape(query_label,
                                     (batch_size, num_class, num_query))

    awgim = model.AWGIM(args, keep_prob, is_training)
    loss_cls, accuracy, tr_loss, tr_accuracy, support_reconstruction, query_reconstruction = \
        awgim.forward(support_feature, support_label_reshape, query_feature, query_label_reshape)
    reg_term = tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'kernel' in v.name
    ])
    loss_meta = loss_cls + args.alpha_1 * tr_loss + args.alpha_2 * support_reconstruction + args.alpha_3 * query_reconstruction
    Batch = tf.Variable(0,
                        trainable=False,
                        dtype=tf.float32,
                        name='global_step')
    learning_rate = tf.train.exponential_decay(
        learning_rate=args.learning_rate,
        global_step=Batch,
        decay_steps=args.step_size,
        decay_rate=0.2,
        staircase=True)
    optim = tf.contrib.opt.AdamWOptimizer(learning_rate=learning_rate,
                                          weight_decay=args.weight_decay)
    meta_weights = [v for v in tf.trainable_variables()]
    print(meta_weights)

    if args.stage == 'train':
        meta_gradients = utils.grads_and_vars(loss_meta, meta_weights,
                                              reg_term)
        meta_gradients = utils.clip_gradients(meta_gradients,
                                              args.gradient_threshold,
                                              args.gradient_norm_threshold)
        train_op = optim.apply_gradients(zip(meta_gradients, meta_weights),
                                         global_step=Batch)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    save_path = utils.save(args)
    print(save_path)
    os.makedirs(save_path, exist_ok=True)
    if args.stage == 'test':
        print(tf.train.latest_checkpoint(save_path))
        saver.restore(sess, tf.train.latest_checkpoint(save_path))
        print('load model')
    if args.data_set == 'mini':
        loader_train = dataset_mini.dataset_mini('train', args)
        loader_val = dataset_mini.dataset_mini('val', args)
        loader_test = dataset_mini.dataset_mini('test', args)
    else:
        loader_train = dataset_tiered.dataset_tiered('train', args)
        loader_val = dataset_tiered.dataset_tiered('val', args)
        loader_test = dataset_tiered.dataset_tiered('test', args)

    if args.stage == 'train':
        print('Load PKL data')
        loader_train.load_data_pkl()
        loader_val.load_data_pkl()
    else:
        loader_test.load_data_pkl()

    val_best_accuracy = 0.
    n_iter = 0
    record_val_acc = []
    if args.stage == 'train':
        for epoch in range(args.epoch):
            training_accuracy, training_loss, acc_cp, acc_real, c_loss, d_loss, g_loss = [], [], [], [], [], [], []
            # training_loss_cls = []
            for epi in range(100):
                support_input, s_labels, query_input, q_labels = utils.load_batch(
                    args, loader_train, args.batch_size, True, loader_val)
                feed_dict = {
                    support_x: support_input,
                    support_label: s_labels,
                    query_x: query_input,
                    query_label: q_labels,
                    is_training: True,
                    keep_prob: 1. - args.dropout
                }
                outs = sess.run([train_op, loss_meta, accuracy, Batch],
                                feed_dict=feed_dict)
                training_accuracy.append(outs[2])
                training_loss.append(outs[1])
                n_iter += 1
            if (epoch + 1) % 3 == 0:
                print('epoch: {} accuracy: {:.4f} loss: {:.4f}'.format(
                    epoch + 1, np.mean(training_accuracy),
                    np.mean(training_loss)))
            if (epoch + 1) % 3 == 0:
                accuracy_val = []
                loss_val = []
                for epi in range(100):
                    support_input, s_labels, query_input, q_labels = utils.load_batch(
                        args, loader_val, args.batch_size, training=False)
                    outs = sess.run(
                        [loss_meta, accuracy, Batch],
                        feed_dict={
                            support_x: support_input,
                            support_label: s_labels,
                            query_x: query_input,
                            query_label: q_labels,
                            is_training: False,
                            keep_prob: 1.
                        })
                    accuracy_val.append(outs[1])
                    loss_val.append(outs[0])
                mean_acc = np.mean(accuracy_val)
                std_acc = np.std(accuracy_val)
                ci95 = 1.96 * std_acc / np.sqrt(100)
                print(
                    ' Val Acc:{:.4f},std:{:.4f},ci95:{:.4f}'.format(
                        mean_acc, std_acc, ci95), 'at epoch: ', epoch + 1)
                record_val_acc.append(mean_acc)
                if mean_acc > val_best_accuracy:
                    val_best_accuracy = mean_acc
                    saver.save(sess,
                               save_path=save_path + 'model.ckpt',
                               global_step=Batch)
            if (epoch + 1) % 100 == 0:
                saver.save(sess,
                           save_path=save_path + 'model.ckpt',
                           global_step=Batch)
    elif args.stage == 'test':
        accuracy_test = []
        loss_test = []
        num = 600
        for epi in range(num):
            support_input, s_labels, query_input, q_labels = utils.load_batch(
                args, loader_test, args.batch_size, False)
            outs = sess.run(
                [loss_meta, accuracy],
                feed_dict={
                    support_x: support_input,
                    support_label: s_labels,
                    query_x: query_input,
                    query_label: q_labels,
                    is_training: False,
                    keep_prob: 1.
                })
            accuracy_test.append(outs[1])
            loss_test.append(outs[0])
        mean_acc = np.mean(accuracy_test)
        std_acc = np.std(accuracy_test)
        ci95 = 1.96 * std_acc / np.sqrt(num)
        print('Acc:{:.4f},std:{:.4f},ci95:{:.4f}'.format(
            mean_acc, std_acc, ci95))

    sess.close()
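The (gradient_threshold, gradient_norm_threshold) pair passed to utils.clip_gradients here mirrors the two-threshold form of Examples #2 and #3; the TensorFlow sketch given after Example #3 applies unchanged, except that nonzero thresholds actually enable the value and global-norm clipping stages.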
Example #9
def main():
    parser = argparse.ArgumentParser(
        "DINO training CLI",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("-b", "--batch-size", type=int, default=4)
    parser.add_argument("-d",
                        "--device",
                        type=str,
                        choices=("cpu", "cuda"),
                        default="cuda")
    parser.add_argument("-l", "--logging-freq", type=int, default=200)
    parser.add_argument("--momentum-teacher", type=float, default=0.9995)
    parser.add_argument("-c", "--n-crops", type=int, default=4)
    parser.add_argument("-e", "--n-epochs", type=int, default=100)
    parser.add_argument("-o", "--out-dim", type=int, default=1024)
    parser.add_argument("-t", "--tensorboard-dir", type=str, default="logs")
    parser.add_argument("--clip-grad", type=float, default=2.0)
    parser.add_argument("--norm-last-layer", action="store_true")
    parser.add_argument("--batch-size-eval", type=int, default=8)
    parser.add_argument("--teacher-temp", type=float, default=0.04)
    parser.add_argument("--student-temp", type=float, default=0.1)
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("-w", "--weight-decay", type=float, default=0.4)

    args = parser.parse_args()
    print(vars(args))
    # Parameters
    vit_name, dim = "deit_small_patch16_224", 384
    path_dataset_train = pathlib.Path("data/imagenette2-320/train")
    path_dataset_val = pathlib.Path("data/imagenette2-320/val")
    path_labels = pathlib.Path("data/imagenette_labels.json")

    logging_path = pathlib.Path(args.tensorboard_dir)
    device = torch.device(args.device)

    n_workers = 1  # my machine can handle 2 workers at most

    # Data related
    with path_labels.open("r") as f:
        label_mapping = json.load(f)

    transform_aug = DataAugmentation(size=224, n_local_crops=args.n_crops - 2)
    transform_plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224)),
    ])

    dataset_train_aug = ImageFolder(path_dataset_train,
                                    transform=transform_aug)
    dataset_train_plain = ImageFolder(path_dataset_train,
                                      transform=transform_plain)
    dataset_val_plain = ImageFolder(path_dataset_val,
                                    transform=transform_plain)

    if dataset_train_plain.classes != dataset_val_plain.classes:
        raise ValueError("Inconsistent classes")

    data_loader_train_aug = DataLoader(
        dataset_train_aug,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=n_workers,
        pin_memory=True,
    )
    data_loader_train_plain = DataLoader(
        dataset_train_plain,
        batch_size=args.batch_size_eval,
        drop_last=False,
        num_workers=n_workers,
    )
    data_loader_val_plain = DataLoader(
        dataset_val_plain,
        batch_size=args.batch_size_eval,
        drop_last=False,
        num_workers=n_workers,
    )
    data_loader_val_plain_subset = DataLoader(
        dataset_val_plain,
        batch_size=args.batch_size_eval,
        drop_last=False,
        sampler=SubsetRandomSampler(list(range(0, len(dataset_val_plain),
                                               50))),
        num_workers=n_workers,
    )

    # Logging
    writer = SummaryWriter(logging_path)
    writer.add_text("arguments", json.dumps(vars(args)))

    # Neural network related
    student_vit = timm.create_model(vit_name, pretrained=args.pretrained)
    teacher_vit = timm.create_model(vit_name, pretrained=args.pretrained)

    student = MultiCropWrapper(
        student_vit,
        Head(
            dim,
            args.out_dim,
            norm_last_layer=args.norm_last_layer,
        ),
    )
    teacher = MultiCropWrapper(teacher_vit, Head(dim, args.out_dim))
    student, teacher = student.to(device), teacher.to(device)

    teacher.load_state_dict(student.state_dict())

    for p in teacher.parameters():
        p.requires_grad = False

    # Loss related
    loss_inst = Loss(
        args.out_dim,
        teacher_temp=args.teacher_temp,
        student_temp=args.student_temp,
    ).to(device)
    lr = 0.0005 * args.batch_size / 256
    optimizer = torch.optim.AdamW(
        student.parameters(),
        lr=lr,
        weight_decay=args.weight_decay,
    )

    # Training loop
    n_batches = len(dataset_train_aug) // args.batch_size
    best_acc = 0
    n_steps = 0

    for e in range(args.n_epochs):
        for i, (images, _) in tqdm.tqdm(enumerate(data_loader_train_aug),
                                        total=n_batches):
            if n_steps % args.logging_freq == 0:
                student.eval()

                # Embedding
                embs, imgs, labels_ = compute_embedding(
                    student.backbone,
                    data_loader_val_plain_subset,
                )
                writer.add_embedding(
                    embs,
                    metadata=[label_mapping[l] for l in labels_],
                    label_img=imgs,
                    global_step=n_steps,
                    tag="embeddings",
                )

                # KNN
                current_acc = compute_knn(
                    student.backbone,
                    data_loader_train_plain,
                    data_loader_val_plain,
                )
                writer.add_scalar("knn-accuracy", current_acc, n_steps)
                if current_acc > best_acc:
                    torch.save(student, logging_path / "best_model.pth")
                    best_acc = current_acc

                student.train()

            images = [img.to(device) for img in images]

            teacher_output = teacher(images[:2])
            student_output = student(images)

            loss = loss_inst(student_output, teacher_output)

            optimizer.zero_grad()
            loss.backward()
            clip_gradients(student, args.clip_grad)
            optimizer.step()

            with torch.no_grad():
                for student_ps, teacher_ps in zip(student.parameters(),
                                                  teacher.parameters()):
                    teacher_ps.data.mul_(args.momentum_teacher)
                    teacher_ps.data.add_(
                        (1 - args.momentum_teacher) * student_ps.detach().data)

            writer.add_scalar("train_loss", loss, n_steps)

            n_steps += 1
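The teacher EMA update above mutates each parameter in two statements. An equivalent, slightly more idiomatic form using Tensor.add_ with an alpha argument (same fixed momentum, no schedule) would be:

with torch.no_grad():
    m = args.momentum_teacher
    for student_ps, teacher_ps in zip(student.parameters(),
                                      teacher.parameters()):
        teacher_ps.data.mul_(m).add_(student_ps.detach().data, alpha=1 - m)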
Example #10
def main(args):
    if args.model_name is not None:
        print('Preparing to train model: {}'.format(args.model_name))

    global device
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    sc_will_happen = args.self_critical_from_epoch != -1

    if args.validate is None and args.lr_scheduler == 'ReduceLROnPlateau':
        print(
            'ERROR: you need to enable validation in order to use default lr_scheduler (ReduceLROnPlateau)'
        )
        print('Hint: use something like --validate=coco:val2017')
        sys.exit(1)

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        # transforms.Resize((256, 256)),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    scorers = {}
    if args.validation_scoring is not None or sc_will_happen:
        assert not (
            args.validation_scoring is None and sc_will_happen
        ), "Please provide a metric when using self-critical training"
        for s in args.validation_scoring.split(','):
            s = s.lower().strip()
            if s == 'cider':
                from eval.cider import Cider
                scorers['CIDEr'] = Cider()
            if s == 'ciderd':
                from eval.ciderD.ciderD import CiderD
                scorers['CIDEr-D'] = CiderD(df=args.cached_words)

    ########################
    # Set Model parameters #
    ########################

    # Store parameters gotten from arguments separately:
    arg_params = ModelParams.fromargs(args)

    print("Model parameters inferred from command arguments: ")
    print(arg_params)
    start_epoch = 0

    ###############################
    # Load existing model state   #
    # and update Model parameters #
    ###############################

    state = None

    if args.load_model:
        try:
            state = torch.load(args.load_model, map_location=device)
        except AttributeError:
            print(
                'WARNING: Old model found. Please run model_update.py on the model before executing this script.'
            )
            exit(1)
        new_external_features = arg_params.features.external

        params = ModelParams(state, arg_params=arg_params)
        if len(new_external_features
               ) and params.features.external != new_external_features:
            print('WARNING: external features changed: ',
                  params.features.external, new_external_features)
            print('Updating feature paths...')
            params.update_ext_features(new_external_features)
        start_epoch = state['epoch']
        print('Loaded model {} at epoch {}'.format(args.load_model,
                                                   start_epoch))
    else:
        params = arg_params
        params.command_history = []

    if params.rnn_hidden_init == 'from_features' and params.skip_start_token:
        print(
            "ERROR: Please remove --skip_start_token if you want to use image features "
            "to initialize hidden and cell states. The <start> token is needed to trigger "
            "the process of sequence generation, since we don't have an image feature "
            "embedding as the first input token.")
        sys.exit(1)

    # Force set the following hierarchical model parameters every time:
    if arg_params.hierarchical_model:
        params.hierarchical_model = True
        params.max_sentences = arg_params.max_sentences
        params.weight_sentence_loss = arg_params.weight_sentence_loss
        params.weight_word_loss = arg_params.weight_word_loss
        params.dropout_stopping = arg_params.dropout_stopping
        params.dropout_fc = arg_params.dropout_fc
        params.coherent_sentences = arg_params.coherent_sentences
        params.coupling_alpha = arg_params.coupling_alpha
        params.coupling_beta = arg_params.coupling_beta

    assert args.replace or \
        not os.path.isdir(os.path.join(args.output_root, args.model_path, get_model_name(args, params))) or \
        not (args.load_model and not args.validate_only), \
        '{} already exists. If you want to replace it or resume training please use --replace flag. ' \
        'If you want to validate a loaded model without training it, use --validate_only flag. ' \
        'Otherwise specify a different model name using --model_name flag.' \
        .format(os.path.join(args.output_root, args.model_path, get_model_name(args, params)))

    if args.load_model:
        print("Final model parameters (loaded model + command arguments): ")
        print(params)

    ##############################
    # Load dataset configuration #
    ##############################

    dataset_configs = DatasetParams(args.dataset_config_file)

    if args.dataset is None and not args.validate_only:
        print('ERROR: No dataset selected!')
        print(
            'Please supply a training dataset with the argument --dataset DATASET'
        )
        print('The following datasets are configured in {}:'.format(
            args.dataset_config_file))
        for ds, _ in dataset_configs.config.items():
            if ds not in ('DEFAULT', 'generic'):
                print(' ', ds)
        sys.exit(1)

    if args.validate_only:
        if args.load_model is None:
            print(
                'ERROR: for --validate_only you need to specify a model to evaluate using --load_model MODEL'
            )
            sys.exit(1)
    else:
        dataset_params = dataset_configs.get_params(args.dataset)

        for i in dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    if args.validate is not None:
        validation_dataset_params = dataset_configs.get_params(args.validate)
        for i in validation_dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    #######################
    # Load the vocabulary #
    #######################

    # For pre-trained models attempt to obtain
    # saved vocabulary from the model itself:
    if args.load_model and params.vocab is not None:
        print("Loading vocabulary from the model file:")
        vocab = params.vocab
    else:
        if args.vocab is None:
            print(
                "ERROR: You must specify the vocabulary to be used for training using "
                "--vocab flag.\nTry --vocab AUTO if you want the vocabulary to be "
                "either generated from the training dataset or loaded from cache."
            )
            sys.exit(1)
        print("Loading / generating vocabulary:")
        vocab = get_vocab(args, dataset_params)

    print('Size of the vocabulary is {}'.format(len(vocab)))

    ##########################
    # Initialize data loader #
    ##########################

    ext_feature_sets = [
        params.features.external, params.persist_features.external
    ]
    if not args.validate_only:
        print('Loading dataset: {} with {} workers'.format(
            args.dataset, args.num_workers))
        if params.skip_start_token:
            print("Skipping the use of <start> token...")
        data_loader, ef_dims = get_loader(
            dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose,
            unique_ids=sc_will_happen)
        if sc_will_happen:
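            # Self-critical training scores sampled captions against the raw
            # ground-truth captions (e.g. with CIDEr), so collect them up front: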
            gts_sc = get_ground_truth_captions(data_loader.dataset)

    gts_sc_valid = None
    if args.validate is not None:
        valid_loader, ef_dims = get_loader(
            validation_dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose)
        gts_sc_valid = get_ground_truth_captions(
            valid_loader.dataset) if sc_will_happen else None

    #########################################
    # Setup (optional) TensorBoardX logging #
    #########################################

    writer = None
    if args.tensorboard:
        if SummaryWriter is not None:
            model_name = get_model_name(args, params)
            timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
            log_dir = os.path.join(
                args.output_root, 'log_tb/{}_{}'.format(model_name, timestamp))
            writer = SummaryWriter(log_dir=log_dir)
            print("INFO: Logging TensorBoardX events to {}".format(log_dir))
        else:
            print(
                "WARNING: SummaryWriter object not available. "
                "Hint: Please install TensorBoardX using pip install tensorboardx"
            )

    ######################
    # Build the model(s) #
    ######################

    # Set per parameter learning rate here, if supplied by the user:

    if args.lr_word_decoder is not None:
        if not params.hierarchical_model:
            print(
                "ERROR: Setting word decoder learning rate currently supported in Hierarchical Model only."
            )
            sys.exit(1)

        lr_dict = {'word_decoder': args.lr_word_decoder}
    else:
        lr_dict = {}

    model = EncoderDecoder(params,
                           device,
                           len(vocab),
                           state,
                           ef_dims,
                           lr_dict=lr_dict)

    ######################
    # Optimizer and loss #
    ######################

    sc_activated = False
    opt_params = model.get_opt_params()

    # Loss and optimizer
    if params.hierarchical_model:
        criterion = HierarchicalXEntropyLoss(
            weight_sentence_loss=params.weight_sentence_loss,
            weight_word_loss=params.weight_word_loss)
    elif args.share_embedding_weights:
        criterion = SharedEmbeddingXentropyLoss(param_lambda=0.15)
    else:
        criterion = nn.CrossEntropyLoss()

    if sc_will_happen:  # save it for later
        if args.self_critical_loss == 'sc':
            from model.loss import SelfCriticalLoss
            rl_criterion = SelfCriticalLoss()
        elif args.self_critical_loss == 'sc_with_diversity':
            from model.loss import SelfCriticalWithDiversityLoss
            rl_criterion = SelfCriticalWithDiversityLoss()
        elif args.self_critical_loss == 'sc_with_relative_diversity':
            from model.loss import SelfCriticalWithRelativeDiversityLoss
            rl_criterion = SelfCriticalWithRelativeDiversityLoss()
        elif args.self_critical_loss == 'sc_with_bleu_diversity':
            from model.loss import SelfCriticalWithBLEUDiversityLoss
            rl_criterion = SelfCriticalWithBLEUDiversityLoss()
        elif args.self_critical_loss == 'sc_with_repetition':
            from model.loss import SelfCriticalWithRepetitionLoss
            rl_criterion = SelfCriticalWithRepetitionLoss()
        elif args.self_critical_loss == 'mixed':
            from model.loss import MixedLoss
            rl_criterion = MixedLoss()
        elif args.self_critical_loss == 'mixed_with_face':
            from model.loss import MixedWithFACELoss
            rl_criterion = MixedWithFACELoss(vocab_size=len(vocab))
        elif args.self_critical_loss in [
                'sc_with_penalty', 'sc_with_penalty_throughout',
                'sc_masked_tokens'
        ]:
            raise ValueError('Deprecated loss, use \'sc\' loss')
        else:
            raise ValueError('Invalid self-critical loss')

        print('Selected self-critical loss is', rl_criterion)

        if start_epoch >= args.self_critical_from_epoch:
            criterion = rl_criterion
            sc_activated = True
            print('Self-critical loss training begins')

    # When using CyclicalLR, the default learning rate should always be 1.0:
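    # (LambdaLR multiplies the base LR by the factor returned by the schedule
    # function, so the base must be 1.0 for the cycle bounds to apply as-is.)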
    if args.lr_scheduler == 'CyclicalLR':
        default_lr = 1.
    else:
        default_lr = 0.001

    if sc_activated:
        optimizer = torch.optim.Adam(
            opt_params,
            lr=args.learning_rate if args.learning_rate else 5e-5,
            weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(opt_params,
                                     lr=default_lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(opt_params,
                                        lr=default_lr,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(opt_params,
                                    lr=default_lr,
                                    weight_decay=args.weight_decay)
    else:
        print('ERROR: unknown optimizer:', args.optimizer)
        sys.exit(1)

    # We don't want to initialize the optimizer if we are transferring
    # the language model from the regular model to the hierarchical model
    transfer_language_model = False

    if arg_params.hierarchical_model and state and not state.get(
            'hierarchical_model'):
        transfer_language_model = True

    # Set optimizer state to the one found in a loaded model, unless
    # we are doing a transfer learning step from flat to hierarchical model,
    # or we are using self-critical loss,
    # or the number of unique parameter groups has changed, or the user
    # has explicitly told us *not to* reuse optimizer parameters from before
    if state and not transfer_language_model and not sc_activated and not args.optimizer_reset:
        # Check that number of parameter groups is the same
        if len(optimizer.param_groups) == len(
                state['optimizer']['param_groups']):
            optimizer.load_state_dict(state['optimizer'])

    # override lr if set explicitly in arguments -
    # 1) Global learning rate:
    if args.learning_rate:
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.learning_rate
        params.learning_rate = args.learning_rate
    else:
        params.learning_rate = default_lr

    # 2) Parameter-group specific learning rate:
    if args.lr_word_decoder is not None:
        # We want to give the user an option to set the learning rate for the
        # word_decoder separately. Other exceptions can be added as needed:
        for param_group in optimizer.param_groups:
            if param_group.get('name') == 'word_decoder':
                param_group['lr'] = args.lr_word_decoder
                break

    if args.validate is not None and args.lr_scheduler == 'ReduceLROnPlateau':
        print('Using ReduceLROnPlateau learning rate scheduler')
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               verbose=True,
                                                               patience=2)
    elif args.lr_scheduler == 'StepLR':
        print('Using StepLR learning rate scheduler with step_size {}'.format(
            args.lr_step_size))
        # Decrease the learning rate by the factor of gamma at every
        # step_size epochs (for example every 5 or 10 epochs):
        step_size = args.lr_step_size
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size,
                                                    gamma=0.5,
                                                    last_epoch=-1)
    elif args.lr_scheduler == 'CyclicalLR':
        print(
            "Using Cyclical learning rate scheduler, lr range: [{},{}]".format(
                args.lr_cyclical_min, args.lr_cyclical_max))

        step_size = len(data_loader)
        clr = cyclical_lr(step_size,
                          min_lr=args.lr_cyclical_min,
                          max_lr=args.lr_cyclical_max)
        n_groups = len(optimizer.param_groups)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      [clr] * n_groups)
    elif args.lr_scheduler is not None:
        print('ERROR: Invalid learning rate scheduler specified: {}'.format(
            args.lr_scheduler))
        sys.exit(1)

    ###################
    # Train the model #
    ###################

    stats_postfix = None
    if args.validate_only:
        stats_postfix = args.validate
    if args.load_model:
        all_stats = init_stats(args, params, postfix=stats_postfix)
    else:
        all_stats = {}

    if args.force_epoch:
        start_epoch = args.force_epoch - 1

    if not args.validate_only:
        total_step = len(data_loader)
        print(
            'Start training with start_epoch={:d} num_epochs={:d} num_batches={:d} ...'
            .format(start_epoch, args.num_epochs, args.num_batches))

    if args.teacher_forcing != 'always':
        print('\t k: {}'.format(args.teacher_forcing_k))
        print('\t beta: {}'.format(args.teacher_forcing_beta))
    print('Optimizer:', optimizer)

    if args.validate_only:
        stats = {}
        teacher_p = 1.0
        if args.teacher_forcing != 'always':
            print(
                'WARNING: teacher_forcing!=always, not yet implemented for --validate_only mode'
            )

        epoch = start_epoch - 1
        if str(epoch + 1) in all_stats and args.skip_existing_validations:
            print('WARNING: epoch {} already validated, skipping...'.format(
                epoch + 1))
            return

        val_loss = do_validate(model, valid_loader, criterion, scorers, vocab,
                               teacher_p, args, params, stats, epoch,
                               sc_activated, gts_sc_valid)
        all_stats[str(epoch + 1)] = stats
        save_stats(args, params, all_stats, postfix=stats_postfix)
    else:
        for epoch in range(start_epoch, args.num_epochs):
            stats = {}
            begin = datetime.now()

            total_loss = 0

            if params.hierarchical_model:
                total_loss_sent = 0
                total_loss_word = 0

            num_batches = 0
            vocab_counts = {
                'cnt': 0,
                'max': 0,
                'min': 9999,
                'sum': 0,
                'unk_cnt': 0,
                'unk_sum': 0
            }

            # Switch to self-critical training once the configured epoch is reached:
            if not sc_activated and sc_will_happen and epoch >= args.self_critical_from_epoch:
                if all_stats:
                    best_ep, best_cider = max(
                        [(ep, all_stats[ep]['validation_cider'])
                         for ep in all_stats],
                        key=lambda x: x[1])
                    print('Loading model from epoch', best_ep,
                          'which has the best validation CIDEr:', best_cider)
                    state = torch.load(
                        get_model_path(args, params, int(best_ep)))
                    model = EncoderDecoder(params,
                                           device,
                                           len(vocab),
                                           state,
                                           ef_dims,
                                           lr_dict=lr_dict)
                    opt_params = model.get_opt_params()

                optimizer = torch.optim.Adam(opt_params,
                                             lr=5e-5,
                                             weight_decay=args.weight_decay)
                criterion = rl_criterion
                print('Self-critical loss training begins')
                sc_activated = True

            for i, data in enumerate(data_loader):

                if params.hierarchical_model:
                    (images, captions, lengths, image_ids, features,
                     sorting_order, last_sentence_indicator) = data
                    sorting_order = sorting_order.to(device)
                else:
                    (images, captions, lengths, image_ids, features) = data

                if epoch == 0:
                    unk = vocab('<unk>')
                    for j in range(captions.shape[0]):
                        # Flatten the caption in case it's a paragraph
                        # this is harmless for regular captions too:
                        xl = captions[j, :].view(-1)
                        xw = xl > unk
                        xu = xl == unk
                        xwi = xw.sum().item()
                        xui = xu.sum().item()
                        vocab_counts['cnt'] += 1
                        vocab_counts['sum'] += xwi
                        vocab_counts['max'] = max(vocab_counts['max'], xwi)
                        vocab_counts['min'] = min(vocab_counts['min'], xwi)
                        vocab_counts['unk_cnt'] += xui > 0
                        vocab_counts['unk_sum'] += xui
                # Set mini-batch dataset
                images = images.to(device)
                captions = captions.to(device)

                # Remove <start> token from targets if we are initializing the RNN
                # hidden state from image features:
                if params.rnn_hidden_init == 'from_features' and not params.hierarchical_model:
                    # Subtract one from all lengths to match new target lengths:
                    lengths = [x - 1 if x > 0 else x for x in lengths]
                    targets = pack_padded_sequence(captions[:, 1:],
                                                   lengths,
                                                   batch_first=True)[0]
                else:
                    if params.hierarchical_model:
                        targets = prepare_hierarchical_targets(
                            last_sentence_indicator, args.max_sentences,
                            lengths, captions, device)
                    else:
                        targets = pack_padded_sequence(captions,
                                                       lengths,
                                                       batch_first=True)[0]
                        sorting_order = None

                init_features = (features[0].to(device)
                                 if len(features) > 0 and features[0] is not None
                                 else None)
                persist_features = (features[1].to(device)
                                    if len(features) > 1 and features[1] is not None
                                    else None)
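                # init_features come from params.features, persist_features from
                # params.persist_features (see ext_feature_sets above); either may be absent.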

                # Forward, backward and optimize
                # Calculate the probability whether to use teacher forcing or not:

                # Iterate over batches:
                iteration = (epoch - start_epoch) * len(data_loader) + i

                teacher_p = get_teacher_prob(args.teacher_forcing_k, iteration,
                                             args.teacher_forcing_beta)
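                # Note: get_teacher_prob is not shown in this listing; presumably it
                # decays teacher_p from 1.0 toward 0 as iteration grows (scheduled
                # sampling, e.g. an inverse-sigmoid decay controlled by k and beta).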

                # Allow model to log values at the last batch of the epoch
                writer_data = None
                if writer and (i == len(data_loader) - 1
                               or i == args.num_batches - 1):
                    writer_data = {'writer': writer, 'epoch': epoch + 1}

                sample_len = (captions.size(1)
                              if args.self_critical_loss in ['mixed', 'mixed_with_face']
                              else 20)
                if sc_activated:
                    sampled_seq, sampled_log_probs, outputs = model.sample(
                        images,
                        init_features,
                        persist_features,
                        max_seq_length=sample_len,
                        start_token_id=vocab('<start>'),
                        trigram_penalty_alpha=args.trigram_penalty_alpha,
                        stochastic_sampling=True,
                        output_logprobs=True,
                        output_outputs=True)
                    sampled_seq = model.decoder.alt_prob_to_tensor(
                        sampled_seq, device=device)
                else:
                    outputs = model(images,
                                    init_features,
                                    captions,
                                    lengths,
                                    persist_features,
                                    teacher_p,
                                    args.teacher_forcing,
                                    sorting_order,
                                    writer_data=writer_data)

                if args.share_embedding_weights:
                    # Weights of (HxH) projection matrix used for regularizing
                    # models that share embedding weights
                    projection = model.decoder.projection.weight
                    loss = criterion(projection, outputs, targets)
                elif sc_activated:
                    # get greedy decoding baseline
                    model.eval()
                    with torch.no_grad():
                        greedy_sampled_seq = model.sample(
                            images,
                            init_features,
                            persist_features,
                            max_seq_length=sample_len,
                            start_token_id=vocab('<start>'),
                            trigram_penalty_alpha=args.trigram_penalty_alpha,
                            stochastic_sampling=False)
                        greedy_sampled_seq = model.decoder.alt_prob_to_tensor(
                            greedy_sampled_seq, device=device)
                    model.train()

                    if args.self_critical_loss in [
                            'sc', 'sc_with_diversity',
                            'sc_with_relative_diversity',
                            'sc_with_bleu_diversity', 'sc_with_repetition'
                    ]:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            greedy_sampled_seq, [gts_sc[img_id] for img_id in image_ids],
                            scorers,
                            vocab,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[img_id] for img_id in image_ids],
                            scorers,
                            vocab,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed_with_face']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[img_id] for img_id in image_ids],
                            scorers,
                            vocab,
                            captions,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    else:
                        raise ValueError('Invalid self-critical loss')

                    if writer is not None and i % 100 == 0:
                        writer.add_scalar('training_loss', loss.item(),
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('advantage', advantage,
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('lr',
                                          optimizer.param_groups[0]['lr'],
                                          epoch * len(data_loader) + i)
                else:
                    loss = criterion(outputs, targets)

                model.zero_grad()
                loss.backward()

                # Clip gradients if desired:
                if args.grad_clip is not None:
                    # grad_norms = [x.grad.data.norm(2) for x in opt_params]
                    # batch_max_grad = np.max(grad_norms)
                    # if batch_max_grad > 10.0:
                    #     print('WARNING: gradient norms larger than 10.0')

                    # torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.1)
                    # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 0.1)
                    clip_gradients(optimizer, args.grad_clip)

                # Update weights:
                optimizer.step()

                # CyclicalLR requires us to update LR at every minibatch:
                if args.lr_scheduler == 'CyclicalLR':
                    scheduler.step()

                total_loss += loss.item()

                num_batches += 1

                if params.hierarchical_model:
                    _, loss_sent, _, loss_word = criterion.item_terms()
                    total_loss_sent += float(loss_sent)
                    total_loss_word += float(loss_word)

                # Print log info
                if (i + 1) % args.log_step == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, '
                          'Perplexity: {:5.4f}'.format(epoch + 1,
                                                       args.num_epochs, i + 1,
                                                       total_step, loss.item(),
                                                       np.exp(loss.item())))
                    sys.stdout.flush()

                    if params.hierarchical_model:
                        weight_sent, loss_sent, weight_word, loss_word = \
                            criterion.item_terms()
                        print('Sentence Loss: {:.4f}, '
                              'Word Loss: {:.4f}'.format(
                                  float(loss_sent), float(loss_word)))
                        sys.stdout.flush()

                if i + 1 == args.num_batches:
                    break

            end = datetime.now()

            stats['training_loss'] = total_loss / num_batches

            if params.hierarchical_model:
                stats['loss_sentence'] = total_loss_sent / num_batches
                stats['loss_word'] = total_loss_word / num_batches

            print('Epoch {} duration: {}, average loss: {:.4f}'.format(
                epoch + 1, end - begin, stats['training_loss']))

            save_model(args, params, model.encoder, model.decoder, optimizer,
                       epoch, vocab)

            if epoch == 0:
                vocab_counts['avg'] = vocab_counts['sum'] / vocab_counts['cnt']
                vocab_counts['unk_cnt_per'] = (
                    100 * vocab_counts['unk_cnt'] / vocab_counts['cnt'])
                vocab_counts['unk_sum_per'] = (
                    100 * vocab_counts['unk_sum'] / vocab_counts['sum'])
                # print(vocab_counts)
                print((
                    'Training data contains {sum} words in {cnt} captions (avg. {avg:.1f} w/c)'
                    + ' with {unk_sum} <unk>s ({unk_sum_per:.1f}%)' +
                    ' in {unk_cnt} ({unk_cnt_per:.1f}%) captions').format(
                        **vocab_counts))

            ############################################
            # Validation loss and learning rate update #
            ############################################

            if args.validate is not None and (epoch +
                                              1) % args.validation_step == 0:
                val_loss = do_validate(model, valid_loader, criterion, scorers,
                                       vocab, teacher_p, args, params, stats,
                                       epoch, sc_activated, gts_sc_valid)

                if args.lr_scheduler == 'ReduceLROnPlateau':
                    scheduler.step(val_loss)
            elif args.lr_scheduler == 'StepLR':
                scheduler.step()

            all_stats[str(epoch + 1)] = stats
            save_stats(args, params, all_stats, writer=writer)

            if writer is not None:
                # Log model data to tensorboard
                log_model_data(params, model, epoch + 1, writer)

    if writer is not None:
        writer.close()
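
The cyclical_lr helper used with LambdaLR above is not shown in this listing. A minimal sketch, assuming the common triangular schedule and the base learning rate of 1.0 set earlier (parameter names mirror the call site; the original's exact shape may differ):

import math

def cyclical_lr(step_size, min_lr=3e-4, max_lr=3e-3):
    # Returns a factor function for LambdaLR; with a base LR of 1.0
    # the returned factor is the effective learning rate itself.
    def lr_lambda(it):
        # Triangular wave: rises from min_lr to max_lr over step_size
        # iterations, then falls back to min_lr over the next step_size.
        cycle = math.floor(1 + it / (2 * step_size))
        x = abs(it / step_size - 2 * cycle + 1)
        return min_lr + (max_lr - min_lr) * max(0.0, 1 - x)
    return lr_lambda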
Example #11
def main():
    parser = argparse.ArgumentParser(
        "DINO training CLI",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="vit_tiny",
        choices=["vit_tiny", "vit_small", "vit_base"],
    )
    parser.add_argument("-b", "--batch-size", type=int, default=32)
    parser.add_argument("-d", "--device", type=int, default=0)
    parser.add_argument("--gpu", action="store_true")
    parser.add_argument("-l", "--logging-freq", type=int, default=200)
    parser.add_argument("--momentum-teacher", type=int, default=0.9995)
    parser.add_argument("-c", "--n-crops", type=int, default=4)
    parser.add_argument("-e", "--n-epochs", type=int, default=100)
    parser.add_argument("-o", "--out-dim", type=int, default=1024)
    parser.add_argument("-t", "--tensorboard-dir", type=str, default="")
    parser.add_argument("--optimizer", type=str, default="AdamW")
    parser.add_argument("--clip-grad", type=float, default=2.0)
    parser.add_argument("--norm-last-layer", action="store_true")
    parser.add_argument("--batch-size-eval", type=int, default=64)
    parser.add_argument("--teacher-temp", type=float, default=0.04)
    parser.add_argument("--student-temp", type=float, default=0.1)
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("-w", "--weight-decay", type=float, default=0.4)

    args = parser.parse_args()
    print(vars(args))
    # Parameters
    models = {
        "vit_tiny": [vit_tiny, 192],
        "vit_small": [vit_small, 384],
        "vit_base": [vit_base, 768],
    }
    path_dataset_train = pathlib.Path("data/imagenette2-320/train")
    path_dataset_val = pathlib.Path("data/imagenette2-320/val")
    path_labels = pathlib.Path("data/imagenette_labels.json")

    if args.gpu:
        torch.cuda.empty_cache()
        torch.cuda.set_device(args.device)
        device = torch.cuda.current_device()
        print(f"Current CUDA device: {device}")
    else:
        device = torch.device("cpu")
        print(f"Current device: {device}")

    n_workers = 4

    ##################
    # Data preparation
    ##################
    with path_labels.open("r") as f:
        label_mapping = json.load(f)

    transform_aug = DataAugmentation(size=224, n_local_crops=args.n_crops - 2)
    transform_plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        transforms.Resize((224, 224)),
    ])

    dataset_train_aug = ImageFolder(path_dataset_train,
                                    transform=transform_aug)
    dataset_train_plain = ImageFolder(path_dataset_train,
                                      transform=transform_plain)
    dataset_val_plain = ImageFolder(path_dataset_val,
                                    transform=transform_plain)

    if dataset_train_plain.classes != dataset_val_plain.classes:
        raise ValueError("Inconsistent classes")

    train_dataloader_aug = DataLoader(
        dataset_train_aug,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=n_workers,
        pin_memory=True,
    )
    train_dataloader_plain = DataLoader(
        dataset_train_plain,
        batch_size=args.batch_size_eval,
        drop_last=False,
        num_workers=n_workers,
    )
    val_dataloader_plain = DataLoader(
        dataset_val_plain,
        batch_size=args.batch_size_eval,
        drop_last=False,
        num_workers=n_workers,
    )
    val_dataloader_plain_subset = DataLoader(
        dataset_val_plain,
        batch_size=args.batch_size_eval,
        drop_last=False,
        sampler=SubsetRandomSampler(list(range(0, len(dataset_val_plain),
                                               50))),
        num_workers=n_workers,
    )
    print(f"[INFO] Data loaded")

    #########
    # Logging
    #########
    run = neptune.init(project="beomus/dino-test")
    run["config/parameters"] = json.dumps(vars(args))
    writer = SummaryWriter(log_dir=args.tensorboard_dir)
    writer.add_text("arguments", json.dumps(vars(args)))
    logging_path = pathlib.Path(writer.log_dir)

    wandb.init(project="dino", entity="beomus")
    wandb.config.update(args)

    print(f"[INFO] Logging started")

    #######################
    # Models initialization
    #######################
    model_fn, dim = models[args.model]
    student_vit = model_fn()
    teacher_vit = model_fn()

    student = MultiCropWrapper(
        student_vit,
        MlpHead(in_dim=dim,
                out_dim=args.out_dim,
                norm_last_layer=args.norm_last_layer),
    )
    teacher = MultiCropWrapper(teacher_vit, MlpHead(dim, args.out_dim))
    student, teacher = student.to(device), teacher.to(device)

    teacher.load_state_dict(student.state_dict())
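    # The teacher starts as an exact copy of the student; it is never updated by
    # backprop (gradients are disabled below), only by the EMA step in the training loop.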

    for p in teacher.parameters():
        p.requires_grad = False

    print(f"[INFO]: Model initialized")

    ######
    # Loss
    ######
    loss_inst = Loss(
        out_dim=args.out_dim,
        teacher_temp=args.teacher_temp,
        student_temp=args.student_temp,
    ).to(device)
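    # Linear scaling rule, as in the DINO paper: base LR of 5e-4 per 256 samples.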
    lr = 0.0005 * args.batch_size / 256

    optimizer_kwargs = {
        "params": student.parameters(),
        "lr": lr,
        "weight_decay": args.weight_decay,
        "amsgrad": True,
    }
    if args.optimizer == "SGD":
        optimizer_kwargs["momentum"] = 0.9
        optimizer_kwargs.pop("amsgrad")
    optimizer = getattr(torch.optim, args.optimizer)(**optimizer_kwargs)

    # optimizer = torch.optim.AdamW(
    #     student.parameters(), lr=lr, weight_decay=args.weight_decay
    # )

    model_name = f"{type(student).__name__}"
    with open(f"{logging_path / model_name}_arch.txt", "w") as f:
        f.write(str(student))
    run[f"config/model/{model_name}_arch"].upload(
        f"{logging_path / model_name}_arch.txt")

    optimizer_name = f"{type(optimizer).__name__}"
    with open(f"{logging_path / optimizer_name}.txt", "w") as f:
        f.write(str(optimizer))
    run[f"config/{optimizer_name}"].upload(
        f"{logging_path / optimizer_name}.txt")

    ###############
    # Training loop
    ###############
    n_batches = len(dataset_train_aug) // args.batch_size
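    # With drop_last=True on train_dataloader_aug, this equals len(train_dataloader_aug).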
    n_steps, best_acc = 0, 0

    print(f"[INFO]: Training started")
    for epoch in range(args.n_epochs):
        for i, (images, _) in tqdm.tqdm(enumerate(train_dataloader_aug),
                                        total=n_batches):
            if n_steps % args.logging_freq == 0:
                student.eval()

                # embedding
                embs, imgs, labels_ = compute_embedding(
                    student.backbone, val_dataloader_plain_subset)
                writer.add_embedding(
                    embs,
                    metadata=[label_mapping[l] for l in labels_],
                    label_img=imgs,
                    global_step=n_steps,
                    tag="embeddings",
                )

                # KNN
                current_acc = compute_knn(student.backbone,
                                          train_dataloader_plain,
                                          val_dataloader_plain)
                writer.add_scalar("knn-accuracy", current_acc, n_steps)
                run["metrics/acc"].log(current_acc)
                wandb.log({"accuracy": current_acc})
                if current_acc > best_acc:
                    model_path = str(logging_path / "model_best.pth")
                    torch.save(student, model_path)
                    run["model_checkpoints/my_model"].upload(model_path)
                    best_acc = current_acc

                student.train()

            images = [img.to(device) for img in images]

            teacher_output = teacher(images[:2])
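            # Only the 2 global views pass through the teacher; the student sees all crops.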
            student_output = student(images)

            loss = loss_inst(student_output, teacher_output)

            optimizer.zero_grad()
            loss.backward()
            clip_gradients(student, args.clip_grad)
            optimizer.step()

            with torch.no_grad():
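                # EMA update for the teacher: w_teacher <- m * w_teacher + (1 - m) * w_student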
                for student_ps, teacher_ps in zip(student.parameters(),
                                                  teacher.parameters()):
                    teacher_ps.data.mul_(args.momentum_teacher)
                    teacher_ps.data.add_(
                        (1 - args.momentum_teacher) * student_ps.detach().data)

            writer.add_scalar("train_loss", loss, n_steps)
            run["metrics/loss"].log(loss)
            wandb.log({"loss": loss})

            n_steps += 1

    print(f"[INFO]: Training ended")
    run.stop()
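
The clip_gradients helper called in the training loop above is not shown in this listing. A minimal sketch, assuming it follows DINO's per-parameter L2-norm clipping (note that Example #10 appears to use a different helper of the same name that takes an optimizer instead of a model):

def clip_gradients(model, clip):
    # Rescale each parameter's gradient so that its L2 norm does not exceed
    # clip; return the pre-clipping norms for optional logging.
    norms = []
    for name, p in model.named_parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            norms.append(param_norm.item())
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                p.grad.data.mul_(clip_coef)
    return norms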