Example #1
 def forward(self, input_signal, length):
     length.requires_grad_(False)
     if self.disable_casts:
         with amp.disable_casts():
             if input_signal.dim() == 2:
                 processed_signal = self.featurizer(
                     input_signal.to(torch.float), length)
                 processed_length = self.featurizer.get_seq_len(length)
     else:
         if input_signal.dim() == 2:
             processed_signal = self.featurizer(input_signal, length)
             processed_length = self.featurizer.get_seq_len(length)
     return processed_signal, processed_length
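A minimal sketch of the pattern above, assuming NVIDIA apex is installed and amp.initialize() was called with opt_level="O1"; the featurizer module here is a hypothetical stand-in for the source's:

import torch
from apex import amp

def extract_features(featurizer, signal, length):
    # Spectrogram-style featurizers (FFT, log) are numerically fragile in
    # fp16, so force this block back to full precision under AMP.
    with amp.disable_casts():
        feats = featurizer(signal.float(), length)
        feat_lens = featurizer.get_seq_len(length)
    return feats, feat_lens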
Example #2
 def forward(self, logits_4D, labels_4D, do_rmi=True):
     # explicitly disable fp16 mode because torch.cholesky and
     # torch.inverse aren't supported by half
     logits_4D = logits_4D.float()  # .float() is out-of-place; rebind the result
     labels_4D = labels_4D.float()
     if cfg.TRAIN.FP16:
         with amp.disable_casts():
             loss = self.forward_sigmoid(logits_4D,
                                         labels_4D,
                                         do_rmi=do_rmi)
     else:
         loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi)
     return loss
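Note that torch.Tensor.float() is out-of-place: it returns a new tensor and leaves the original untouched, so the result must be rebound. A quick demonstration:

import torch

t = torch.ones(2, dtype=torch.float16)
t.float()                       # result discarded; t is still float16
assert t.dtype == torch.float16
t = t.float()                   # rebind to actually get float32
assert t.dtype == torch.float32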
Example #3
 def forward(self, x):
     input_signal, length = x
     length.requires_grad_(False)
     if self.optim_level not in [
             Optimization.nothing, Optimization.mxprO0, Optimization.mxprO3
     ]:
         with amp.disable_casts():
             processed_signal = self.featurizer(x)
             processed_length = self.featurizer.get_seq_len(length)
     else:
         processed_signal = self.featurizer(x)
         processed_length = self.featurizer.get_seq_len(length)
     return processed_signal, processed_length
Example #4
    def eval():
        """Evaluates model on evaluation dataset
        """
        with torch.no_grad():
            _global_var_dict = {
                'EvalLoss': [],
                'predictions': [],
                'transcripts': [],
            }
            eval_dataloader = data_layer_eval.data_iterator
            for data in eval_dataloader:
                tensors = []
                for d in data:
                    if isinstance(d, torch.Tensor):
                        tensors.append(d.cuda())
                    else:
                        tensors.append(d)
                t_audio_signal_e, t_a_sig_length_e, t_transcript_e, t_transcript_len_e = tensors

                model.eval()
                if optim_level == 1:
                    with amp.disable_casts():
                        t_processed_signal_e, t_processed_sig_length_e = audio_preprocessor(t_audio_signal_e, t_a_sig_length_e)
                else:
                    t_processed_signal_e, t_processed_sig_length_e = audio_preprocessor(t_audio_signal_e, t_a_sig_length_e)
                if jasper_encoder.use_conv_mask:
                    t_log_probs_e, t_encoded_len_e = model.forward((t_processed_signal_e, t_processed_sig_length_e))
                else:
                    # note: this path produces no t_encoded_len_e, which ctc_loss below requires
                    t_log_probs_e = model.forward(t_processed_signal_e)
                t_loss_e = ctc_loss(log_probs=t_log_probs_e, targets=t_transcript_e, input_length=t_encoded_len_e, target_length=t_transcript_len_e)
                t_predictions_e = greedy_decoder(log_probs=t_log_probs_e)

                values_dict = dict(
                    loss=[t_loss_e],
                    predictions=[t_predictions_e],
                    transcript=[t_transcript_e],
                    transcript_length=[t_transcript_len_e]
                )
                process_evaluation_batch(values_dict, _global_var_dict, labels=labels)

            # final aggregation (across all workers and minibatches) and logging of results
            wer, eloss = process_evaluation_epoch(_global_var_dict)

            print_once("==========>>>>>>Evaluation Loss: {0}\n".format(eloss))
            print_once("==========>>>>>>Evaluation WER: {0}\n".format(wer))
Example #5
    def __call__(self, audio, audio_lens, optim_level=0):
        dtype = audio.dtype    # remember the caller's dtype
        audio = audio.float()  # feature extraction runs in fp32
        if optim_level == 1:
            with amp.disable_casts():
                feat, feat_lens = self.calculate_features(audio, audio_lens)
        else:
            feat, feat_lens = self.calculate_features(audio, audio_lens)

        feat = self.apply_padding(feat)

        if self.cutout_augment is not None:
            feat = self.cutout_augment(feat)

        if self.spec_augment is not None:
            feat = self.spec_augment(feat)

        feat = feat.to(dtype)
        return feat, feat_lens
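The if/else duplication in several of these examples can be collapsed by choosing the context manager once; a sketch, not taken from any of the sources above:

from contextlib import nullcontext
from apex import amp

def featurize(featurizer, audio, audio_lens, optim_level=0):
    # amp.disable_casts() only matters under O1; otherwise use a no-op context.
    ctx = amp.disable_casts() if optim_level == 1 else nullcontext()
    with ctx:
        return featurizer(audio, audio_lens)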
Example #6
    def forward(self,
                src,
                pos,
                reference_points,
                spatial_shapes,
                level_start_index,
                padding_mask=None):
        # self attention
        with amp.disable_casts():
            src2 = self.self_attn(self.with_pos_embed(src, pos),
                                  reference_points, src, spatial_shapes,
                                  level_start_index, padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src
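The pattern in example #6 pins only the numerically sensitive submodule to fp32 while the residual add, dropout, and LayerNorm stay under the normal AMP policy. A self-contained sketch with a standard attention layer standing in for the deformable one (a hypothetical module, not the source's):

import torch
import torch.nn as nn
from apex import amp  # assumes amp.initialize() was called with opt_level="O1"

class FP32AttnBlock(nn.Module):
    def __init__(self, d_model=256, n_heads=8, p_drop=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.dropout = nn.Dropout(p_drop)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        with amp.disable_casts():                  # attention math in fp32
            a, _ = self.attn(x.float(), x.float(), x.float())
        return self.norm(x + self.dropout(a))      # rest follows the AMP policy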
Example #7
def main():
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    cur_timestamp = str(datetime.now())[:-3]  # include milliseconds to reduce the chance of a name collision
    model_width = {'linear': '', 'cnn': args.n_filters_cnn, 'lenet': '', 'resnet18': ''}[args.model]
    model_str = '{}{}'.format(args.model, model_width)
    model_name = '{} dataset={} model={} eps={} attack={} m={} attack_init={} fgsm_alpha={} epochs={} pgd={}-{} grad_align_cos_lambda={} lr_max={} seed={}'.format(
        cur_timestamp, args.dataset, model_str, args.eps, args.attack, args.minibatch_replay, args.attack_init, args.fgsm_alpha, args.epochs,
        args.pgd_alpha_train, args.pgd_train_n_iters, args.grad_align_cos_lambda, args.lr_max, args.seed)
    if not os.path.exists('models'):
        os.makedirs('models')
    logger = utils.configure_logger(model_name, args.debug)
    logger.info(args)
    half_prec = args.half_prec
    n_cls = 2 if 'binary' in args.dataset else 10

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    double_bp = args.grad_align_cos_lambda > 0
    n_eval_every_k_iter = args.n_eval_every_k_iter
    args.pgd_alpha = args.eps / 4

    eps, pgd_alpha, pgd_alpha_train = args.eps / 255, args.pgd_alpha / 255, args.pgd_alpha_train / 255
    train_data_augm = args.dataset not in ['mnist']
    train_batches = data.get_loaders(args.dataset, -1, args.batch_size, train_set=True, shuffle=True, data_augm=train_data_augm)
    train_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size, train_set=True, shuffle=False, data_augm=False)
    test_batches = data.get_loaders(args.dataset, args.n_final_eval, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)
    test_batches_fast = data.get_loaders(args.dataset, n_eval_every_k_iter, args.batch_size_eval, train_set=False, shuffle=False, data_augm=False)

    model = models.get_model(args.model, n_cls, half_prec, data.shapes_dict[args.dataset], args.n_filters_cnn).cuda()
    model.apply(utils.initialize_weights)
    model.train()

    if args.model == 'resnet18':
        opt = torch.optim.SGD(model.parameters(), lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay)
    elif args.model == 'cnn':
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    elif args.model == 'lenet':
        opt = torch.optim.Adam(model.parameters(), lr=args.lr_max, weight_decay=args.weight_decay)
    else:
        raise ValueError('decide about the right optimizer for the new model')

    if half_prec:
        if double_bp:
            amp.register_float_function(torch, 'batch_norm')
        model, opt = amp.initialize(model, opt, opt_level="O1")

    if args.attack == 'fgsm':  # needed here only for Free-AT
        delta = torch.zeros(args.batch_size, *data.shapes_dict[args.dataset][1:]).cuda()
        delta.requires_grad = True

    lr_schedule = utils.get_lr_schedule(args.lr_schedule, args.epochs, args.lr_max)
    loss_function = nn.CrossEntropyLoss()

    train_acc_pgd_best, best_state_dict = 0.0, copy.deepcopy(model.state_dict())
    start_time = time.time()
    time_train, iteration, best_iteration = 0, 0, 0
    for epoch in range(args.epochs + 1):
        train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0
        for i, (X, y) in enumerate(train_batches):
            if i % args.minibatch_replay != 0 and i > 0:  # take new inputs only every `minibatch_replay` iterations
                X, y = X_prev, y_prev
            time_start_iter = time.time()
            # epoch=0 runs only for one iteration (to check the training stats at init)
            if epoch == 0 and i > 0:
                break
            X, y = X.cuda(), y.cuda()
            lr = lr_schedule(epoch - 1 + (i + 1) / len(train_batches))  # epoch - 1 since the 0th epoch is skipped
            opt.param_groups[0].update(lr=lr)

            if args.attack in ['pgd', 'pgd_corner']:
                pgd_rs = args.attack_init == 'random'
                n_eps_warmup_epochs = 5
                n_iterations_max_eps = n_eps_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                eps_pgd_train = min(iteration / n_iterations_max_eps * eps, eps) if args.dataset == 'svhn' else eps
                delta = utils.attack_pgd_training(
                    model, X, y, eps_pgd_train, pgd_alpha_train, opt, half_prec, args.pgd_train_n_iters, rs=pgd_rs)
                if args.attack == 'pgd_corner':
                    delta = eps * utils.sign(delta)  # project to the corners
                    delta = clamp(X + delta, 0, 1) - X

            elif args.attack == 'fgsm':
                if args.minibatch_replay == 1:
                    if args.attack_init == 'zero':
                        delta = torch.zeros_like(X, requires_grad=True)
                    elif args.attack_init == 'random':
                        delta = utils.get_uniform_delta(X.shape, eps, requires_grad=True)
                    else:
                        raise ValueError('wrong args.attack_init')
                else:  # if Free-AT, we just reuse the existing delta from the previous iteration
                    delta.requires_grad = True

                X_adv = clamp(X + delta, 0, 1)
                output = model(X_adv)
                loss = F.cross_entropy(output, y)
                if half_prec:
                    with amp.scale_loss(loss, opt) as scaled_loss:
                        grad = torch.autograd.grad(scaled_loss, delta, create_graph=double_bp)[0]
                        grad /= scaled_loss / loss  # undo the loss scaling on the input gradient
                else:
                    grad = torch.autograd.grad(loss, delta, create_graph=double_bp)[0]

                grad = grad.detach()

                argmax_delta = eps * utils.sign(grad)

                n_alpha_warmup_epochs = 5
                n_iterations_max_alpha = n_alpha_warmup_epochs * data.shapes_dict[args.dataset][0] // args.batch_size
                fgsm_alpha = min(iteration / n_iterations_max_alpha * args.fgsm_alpha, args.fgsm_alpha) if args.dataset == 'svhn' else args.fgsm_alpha
                delta.data = clamp(delta.data + fgsm_alpha * argmax_delta, -eps, eps)
                delta.data = clamp(X + delta.data, 0, 1) - X

            elif args.attack == 'random_corner':
                delta = utils.get_uniform_delta(X.shape, eps, requires_grad=False)
                delta = eps * utils.sign(delta)

            elif args.attack == 'none':
                delta = torch.zeros_like(X, requires_grad=False)
            else:
                raise ValueError('wrong args.attack')

            # extra FP+BP to calculate the gradient to monitor it
            if args.attack in ['none', 'random_corner', 'pgd', 'pgd_corner']:
                grad = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='none',
                                      backprop=args.grad_align_cos_lambda != 0.0)

            delta = delta.detach()

            output = model(X + delta)
            loss = loss_function(output, y)

            reg = torch.zeros(1).cuda()[0]  # for .item() to run correctly
            if args.grad_align_cos_lambda != 0.0:
                grad2 = get_input_grad(model, X, y, opt, eps, half_prec, delta_init='random_uniform', backprop=True)
                grads_nnz_idx = ((grad**2).sum([1, 2, 3])**0.5 != 0) * ((grad2**2).sum([1, 2, 3])**0.5 != 0)
                grad1, grad2 = grad[grads_nnz_idx], grad2[grads_nnz_idx]
                grad1_norms, grad2_norms = l2_norm_batch(grad1), l2_norm_batch(grad2)
                grad1_normalized = grad1 / grad1_norms[:, None, None, None]
                grad2_normalized = grad2 / grad2_norms[:, None, None, None]
                cos = torch.sum(grad1_normalized * grad2_normalized, (1, 2, 3))
                reg += args.grad_align_cos_lambda * (1.0 - cos.mean())

            loss += reg

            if epoch != 0:
                opt.zero_grad()
                utils.backward(loss, opt, half_prec)
                opt.step()

            time_train += time.time() - time_start_iter
            train_loss += loss.item() * y.size(0)
            train_reg += reg.item() * y.size(0)
            train_acc += (output.max(1)[1] == y).sum().item()
            train_n += y.size(0)

            with torch.no_grad():  # no grad for the stats
                grad_norm_x += l2_norm_batch(grad).sum().item()
                delta_final = clamp(X + delta, 0, 1) - X  # we should measure delta after the projection onto [0, 1]^d
                avg_delta_l2 += ((delta_final ** 2).sum([1, 2, 3]) ** 0.5).sum().item()

            if iteration % args.eval_iter_freq == 0:
                train_loss, train_reg = train_loss / train_n, train_reg / train_n
                train_acc, avg_delta_l2 = train_acc / train_n, avg_delta_l2 / train_n

                # it'd be incorrect to recalculate the BN stats on the test sets and for clean / adversarial points
                utils.model_eval(model, half_prec)

                test_acc_clean, _, _ = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_fgsm, test_loss_fgsm, fgsm_deltas = rob_acc(test_batches_fast, model, eps, eps, opt, half_prec, 1, 1, rs=False)
                test_acc_pgd, test_loss_pgd, pgd_deltas = rob_acc(test_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)
                cos_fgsm_pgd = utils.avg_cos_np(fgsm_deltas, pgd_deltas)
                train_acc_pgd, _, _ = rob_acc(train_batches_fast, model, eps, pgd_alpha, opt, half_prec, args.attack_iters, 1)  # needed for early stopping

                grad_x = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=False)
                grad_eta = utils.get_grad_np(model, test_batches_fast, eps, opt, half_prec, rs=True)
                cos_x_eta = utils.avg_cos_np(grad_x, grad_eta)

                time_elapsed = time.time() - start_time
                train_str = '[train] loss {:.3f}, reg {:.3f}, acc {:.2%} acc_pgd {:.2%}'.format(train_loss, train_reg, train_acc, train_acc_pgd)
                test_str = '[test] acc_clean {:.2%}, acc_fgsm {:.2%}, acc_pgd {:.2%}, cos_x_eta {:.3}, cos_fgsm_pgd {:.3}'.format(
                    test_acc_clean, test_acc_fgsm, test_acc_pgd, cos_x_eta, cos_fgsm_pgd)
                logger.info('{}-{}: {}  {} ({:.2f}m, {:.2f}m)'.format(epoch, iteration, train_str, test_str,
                                                                      time_train/60, time_elapsed/60))

                if train_acc_pgd > train_acc_pgd_best:  # catastrophic overfitting can be detected on the training set
                    best_state_dict = copy.deepcopy(model.state_dict())
                    train_acc_pgd_best, best_iteration = train_acc_pgd, iteration

                utils.model_train(model, half_prec)
                train_loss, train_reg, train_acc, train_n, grad_norm_x, avg_delta_l2 = 0, 0, 0, 0, 0, 0

            iteration += 1
            X_prev, y_prev = X.clone(), y.clone()  # needed for Free-AT

        if epoch == args.epochs:
            torch.save({'last': model.state_dict(), 'best': best_state_dict}, 'models/{} epoch={}.pth'.format(model_name, epoch))
            # disable global conversion to fp16 from amp.initialize() (https://github.com/NVIDIA/apex/issues/567)
            context_manager = amp.disable_casts() if half_prec else utils.nullcontext()
            with context_manager:
                last_state_dict = copy.deepcopy(model.state_dict())
                half_prec = False  # final eval is always in fp32
                model.load_state_dict(last_state_dict)
                utils.model_eval(model, half_prec)
                opt = torch.optim.SGD(model.parameters(), lr=0)

                attack_iters, n_restarts = (50, 10) if not args.debug else (10, 3)
                test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                logger.info('[last: test on 10k points] acc_clean {:.2%}, pgd_rr {:.2%}'.format(test_acc_clean, test_acc_pgd_rr))

                if args.eval_early_stopped_model:
                    model.load_state_dict(best_state_dict)
                    utils.model_eval(model, half_prec)
                    test_acc_clean, _, _ = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, 0, 1)
                    test_acc_pgd_rr, _, deltas_pgd_rr = rob_acc(test_batches, model, eps, pgd_alpha, opt, half_prec, attack_iters, n_restarts)
                    logger.info('[best: test on 10k points][iter={}] acc_clean {:.2%}, pgd_rr {:.2%}'.format(
                        best_iteration, test_acc_clean, test_acc_pgd_rr))

        utils.model_train(model, half_prec)

    logger.info('Done in {:.2f}m'.format((time.time() - start_time) / 60))
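One detail worth calling out from example #7: under apex O1 the loss handed to autograd is scaled, so a gradient taken with respect to the input comes back scaled as well, and the code divides the scale back out. A condensed sketch of just that step; model, X, y, delta, and opt are assumed to exist and be amp-initialized:

import torch
import torch.nn.functional as F
from apex import amp

output = model(X + delta)
loss = F.cross_entropy(output, y)
with amp.scale_loss(loss, opt) as scaled_loss:
    # the gradient w.r.t. the input carries the same scale factor as the loss
    grad = torch.autograd.grad(scaled_loss, delta)[0]
    grad = grad / (scaled_loss / loss)  # divide the loss scale back out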
Example #8
            else:
                if (rolling_average > record_rolling_average):
                    # Save model with a munged filename; e.g. doodles_706.pth
                    if (SAVE_BACKUP_FILES):
                        backupPth = NUMBERED_STATE_DICT_FILE_TEMPLATE.format(rolling_average, BATCH_SIZE)
                        torch.save(model.state_dict(), backupPth)
                        print('Saved model file {}'.format(backupPth))
                        # Delete the last backup .pth file we wrote to avoid filling up the drive
                        if (record_rolling_average > 0):
                            old_file = NUMBERED_STATE_DICT_FILE_TEMPLATE.format(record_rolling_average, BATCH_SIZE)
                            if os.path.exists(old_file):
                                os.remove(old_file)
                        # Same for ONNX
                        backupOnnx = NUMBERED_ONNX_FILE_TEMPLATE.format(rolling_average, BATCH_SIZE)
                        if MIXED_PRECISION:
                            with amp.disable_casts():
                                dummy_input = torch.randn(1, 1, 64, 64).to(DEVICE)
                                torch.onnx.export(model, dummy_input, backupOnnx, verbose=False)
                        else:
                            dummy_input = torch.randn(1, 1, 64, 64).to(DEVICE)
                            torch.onnx.export(model, dummy_input, backupOnnx, verbose=False)
                        print('Saved ONNX file {}'.format(backupOnnx))
                        # Delete the last backup ONNX file we wrote to avoid filling up the drive
                        if (record_rolling_average > 0):
                            old_file = NUMBERED_ONNX_FILE_TEMPLATE.format(record_rolling_average, BATCH_SIZE)
                            if os.path.exists(old_file):
                                os.remove(old_file)
                    record_rolling_average = rolling_average

            # Deleting the model file during training triggers a fresh rewrite:
            if not os.path.isfile(STATE_DICT_FILE):
Example #9
def train(data_layer,
          data_layer_eval,
          model,
          ema_model,
          ctc_loss,
          greedy_decoder,
          optimizer,
          optim_level,
          labels,
          multi_gpu,
          args,
          fn_lr_policy=None):
    """Trains model
    Args:
        data_layer: training data layer
        data_layer_eval: evaluation data layer
        model: model ( encapsulates data processing, encoder, decoder)
        ctc_loss: loss function
        greedy_decoder: greedy ctc decoder
        optimizer: optimizer
        optim_level: AMP optimization level
        labels: list of output labels
        multi_gpu: true if multi gpu training
        args: script input argument list
        fn_lr_policy: learning rate adjustment function
    """
    def eval(model, name=''):
        """Evaluates model on evaluation dataset
        """
        with torch.no_grad():
            _global_var_dict = {
                'EvalLoss': [],
                'predictions': [],
                'transcripts': [],
            }
            eval_dataloader = data_layer_eval.data_iterator
            for data in eval_dataloader:
                tensors = []
                for d in data:
                    if isinstance(d, torch.Tensor):
                        tensors.append(d.cuda())
                    else:
                        tensors.append(d)
                t_audio_signal_e, t_a_sig_length_e, t_transcript_e, t_transcript_len_e = tensors

                model.eval()
                if optim_level == 1:
                    with amp.disable_casts():
                        t_processed_signal_e, t_processed_sig_length_e = audio_preprocessor(
                            t_audio_signal_e, t_a_sig_length_e)
                else:
                    t_processed_signal_e, t_processed_sig_length_e = audio_preprocessor(
                        t_audio_signal_e, t_a_sig_length_e)
                if encoder.use_conv_mask:
                    t_log_probs_e, t_encoded_len_e = model.forward(
                        (t_processed_signal_e, t_processed_sig_length_e))
                else:
                    # note: this path produces no t_encoded_len_e, which ctc_loss below requires
                    t_log_probs_e = model.forward(t_processed_signal_e)
                t_loss_e = ctc_loss(log_probs=t_log_probs_e,
                                    targets=t_transcript_e,
                                    input_length=t_encoded_len_e,
                                    target_length=t_transcript_len_e)
                t_predictions_e = greedy_decoder(log_probs=t_log_probs_e)

                values_dict = dict(loss=[t_loss_e],
                                   predictions=[t_predictions_e],
                                   transcript=[t_transcript_e],
                                   transcript_length=[t_transcript_len_e])
                process_evaluation_batch(values_dict,
                                         _global_var_dict,
                                         labels=labels)

            # final aggregation (across all workers and minibatches) and logging of results
            wer, eloss = process_evaluation_epoch(_global_var_dict)

            if name != '':
                name = '_' + name

            print_once(f"==========>>>>>>Evaluation{name} Loss: {eloss}\n")
            print_once(
                f"==========>>>>>>Evaluation{name} WER: {round(wer * 100, 2)}%\n"
            )

    print_once("Starting .....")
    start_time = time.time()

    train_dataloader = data_layer.data_iterator
    epoch = args.start_epoch
    step = epoch * args.step_per_epoch

    audio_preprocessor = model.module.audio_preprocessor if hasattr(
        model, 'module') else model.audio_preprocessor
    data_spectr_augmentation = model.module.data_spectr_augmentation if hasattr(
        model, 'module') else model.data_spectr_augmentation
    encoder = model.module.encoder if hasattr(model,
                                              'module') else model.encoder

    while True:
        if multi_gpu:
            data_layer.sampler.set_epoch(epoch)
        print_once("Starting epoch {0}, step {1}".format(epoch, step))
        last_epoch_start = time.time()
        batch_counter = 0
        average_loss = 0
        for data in train_dataloader:
            tensors = []
            for d in data:
                if isinstance(d, torch.Tensor):
                    tensors.append(d.cuda())
                else:
                    tensors.append(d)

            if batch_counter == 0:

                if fn_lr_policy is not None:
                    adjusted_lr = fn_lr_policy(step)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = adjusted_lr
                optimizer.zero_grad()
                last_iter_start = time.time()

            t_audio_signal_t, t_a_sig_length_t, t_transcript_t, t_transcript_len_t = tensors
            model.train()
            if optim_level == 1:
                with amp.disable_casts():
                    t_processed_signal_t, t_processed_sig_length_t = audio_preprocessor(
                        t_audio_signal_t, t_a_sig_length_t)
            else:
                t_processed_signal_t, t_processed_sig_length_t = audio_preprocessor(
                    t_audio_signal_t, t_a_sig_length_t)
            t_processed_signal_t = data_spectr_augmentation(
                t_processed_signal_t)
            if encoder.use_conv_mask:
                t_log_probs_t, t_encoded_len_t = model.forward(
                    (t_processed_signal_t, t_processed_sig_length_t))
            else:
                t_log_probs_t = model.forward(t_processed_signal_t)

            t_loss_t = ctc_loss(log_probs=t_log_probs_t,
                                targets=t_transcript_t,
                                input_length=t_encoded_len_t,
                                target_length=t_transcript_len_t)
            if args.gradient_accumulation_steps > 1:
                t_loss_t = t_loss_t / args.gradient_accumulation_steps

            if 0 < optim_level <= 3:
                with amp.scale_loss(t_loss_t, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                t_loss_t.backward()
            batch_counter += 1
            average_loss += t_loss_t.item()

            if batch_counter % args.gradient_accumulation_steps == 0:
                optimizer.step()

                if step % args.train_frequency == 0:
                    t_predictions_t = greedy_decoder(log_probs=t_log_probs_t)

                    e_tensors = [
                        t_predictions_t, t_transcript_t, t_transcript_len_t
                    ]
                    train_wer = monitor_asr_train_progress(e_tensors,
                                                           labels=labels)
                    print_once("Loss@Step: {0}  ::::::: {1}".format(
                        step, str(round(average_loss, 2))))
                    print_once("Step time: {0} seconds".format(
                        round(time.time() - last_iter_start, 2)))
                if step > 0 and step % args.eval_frequency == 0:
                    print_once(
                        "Doing Evaluation ....................... ......  ... .. . ."
                    )
                    eval(model)
                    if args.ema > 0:
                        eval(ema_model, 'EMA')

                step += 1
                batch_counter = 0
                average_loss = 0
                if args.num_steps is not None and step >= args.num_steps:
                    break

        if args.num_steps is not None and step >= args.num_steps:
            break
        print_once("Finished epoch {0} in {1}".format(
            epoch,
            time.time() - last_epoch_start))
        epoch += 1
        if epoch % args.save_frequency == 0 and epoch > 0:
            save(model, ema_model, optimizer, epoch, args.output_dir,
                 optim_level)
        if args.num_steps is None and epoch >= args.num_epochs:
            break
    print_once("Done in {0}".format(time.time() - start_time))
    print_once("Final Evaluation ....................... ......  ... .. . .")
    eval(model)
    if args.ema > 0:
        eval(ema_model, 'EMA')
    save(model, ema_model, optimizer, epoch, args.output_dir, optim_level)
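Example #9 combines amp.scale_loss() with gradient accumulation: the loss is divided by gradient_accumulation_steps, every micro-batch backprops through the scaler, and optimizer.step() runs once per accumulation window. A stripped-down sketch of that loop; loader, compute_loss, optimizer, and accum_steps are hypothetical names:

from apex import amp

for i, batch in enumerate(loader):
    loss = compute_loss(batch) / accum_steps   # normalize per micro-batch
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()                 # grads accumulate in .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()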
Example #10
    def evaluate(self, mode):
        '''
            mode: an <int> or 'best'
            <int> is an epoch number, used for evaluation during training
            'best' is used for evaluation after training
        '''

        set_name = 'test'
        eval_model_list = ['depthEstModel']

        if isinstance(mode, int) and self.isTrain:
            self._set_models_eval(eval_model_list)
            if self.EVAL_best_loss == float('inf'):
                fn = open(self.evaluate_log, 'w')
            else:
                fn = open(self.evaluate_log, 'a')

            fn.write('Evaluating with mode: {}\n'.format(mode))
            fn.write('\tEvaluation range min: {} | max: {} \n'.format(
                self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
            fn.close()

        else:
            self._load_models(eval_model_list, mode)

        print('Evaluating with mode: {}'.format(mode))
        print('\tEvaluation range min: {} | max: {}'.format(
            self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))

        total_loss, count = 0., 0
        predTensor = torch.zeros(
            (1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
        grndTensor = torch.zeros(
            (1, 1, self.cropSize_h, self.cropSize_w)).to('cpu')
        idx = 0

        tensorboardX_iter_count = 0
        for sample in self.dataloaders_single[set_name]:
            imageList, depthList = sample
            valid_mask = np.logical_and(depthList > self.EVAL_DEPTH_MIN,
                                        depthList < self.EVAL_DEPTH_MAX)

            idx += imageList.shape[0]
            print('epoch {}: processed {} samples in the {} set'.format(
                mode, str(idx), set_name))
            imageList = imageList.to(self.device)
            depthList = depthList.to(self.device)  # real depth

            if self.isTrain and self.use_apex:
                with amp.disable_casts():
                    predList = self.depthEstModel(imageList)[-1].detach().to(
                        'cpu')
            else:
                predList = self.depthEstModel(imageList)[-1].detach().to('cpu')

            # recover real depth
            predList = (predList + 1.0) * 0.5 * self.NYU_MAX_DEPTH_CLIP
            depthList = depthList.detach().to('cpu')
            predTensor = torch.cat((predTensor, predList), dim=0)
            grndTensor = torch.cat((grndTensor, depthList), dim=0)

            if self.use_tensorboardX:
                nrow = imageList.size()[0]
                if tensorboardX_iter_count % self.val_display_freq == 0:
                    depth_concat = torch.cat((depthList, predList), dim=0)
                    self.write_2_tensorboardX(
                        self.eval_SummaryWriter,
                        depth_concat,
                        name='{}: ground truth and depth prediction'.format(
                            set_name),
                        mode='image',
                        count=tensorboardX_iter_count,
                        nrow=nrow,
                        value_range=(0.0, self.NYU_MAX_DEPTH_CLIP))

                tensorboardX_iter_count += 1

            if isinstance(mode, int) and self.isTrain:
                eval_depth_loss = self.L1loss(predList[valid_mask],
                                              depthList[valid_mask])
                total_loss += eval_depth_loss.detach().cpu()

            count += 1

        if isinstance(mode, int) and self.isTrain:
            validation_loss = (total_loss / count)
            print('validation loss is {:.7f}'.format(validation_loss))
            if self.use_tensorboardX:
                self.write_2_tensorboardX(self.eval_SummaryWriter,
                                          validation_loss,
                                          name='validation loss',
                                          mode='scalar',
                                          count=mode)

        results = Result(mask_min=self.EVAL_DEPTH_MIN,
                         mask_max=self.EVAL_DEPTH_MAX)
        results.evaluate(predTensor[1:], grndTensor[1:])

        result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f} '.format(
            results.absrel, results.sqrel, results.rmse, results.rmselog,
            results.mae)
        result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(
            results.delta1, results.delta2, results.delta3)

        print(result1)
        print(result2)

        if isinstance(mode, int) and self.isTrain:
            self.EVAL_all_results[str(mode)] = result1 + '\t' + result2

            if validation_loss.item() < self.EVAL_best_loss:
                self.EVAL_best_loss = validation_loss.item()
                self.EVAL_best_model_epoch = mode
                self.save_models(self.model_name, mode='best')

            best_model_summary = '\tCurrent best loss {:.7f}, current best model {}\n'.format(
                self.EVAL_best_loss, self.EVAL_best_model_epoch)
            print(best_model_summary)

            fn = open(self.evaluate_log, 'a')
            fn.write(result1 + '\n')
            fn.write(result2 + '\n')
            fn.write(best_model_summary + '\n')
            fn.close()
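The evaluation in example #10 computes its L1 loss only over pixels whose ground-truth depth falls inside the evaluation range. A minimal standalone version of that masking (names hypothetical):

import torch
import torch.nn.functional as F

def masked_l1(pred, gt, d_min, d_max):
    mask = (gt > d_min) & (gt < d_max)     # evaluate only in-range pixels
    return F.l1_loss(pred[mask], gt[mask])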
Example #11
	def evaluate(self, mode):
		'''
			mode: an <int> or 'best'
			<int> is an epoch number, used for evaluation during training
			'best' is used for evaluation after training
		'''

		set_name = 'test'
		eval_model_list = ['depthEstModel']

		if isinstance(mode, int) and self.is_train:
			self._set_models_eval(eval_model_list)
			if self.EVAL_best_loss == float('inf'):
				fn = open(self.evaluate_log, 'w')
			else:
				fn = open(self.evaluate_log, 'a')

			fn.write('Evaluating with mode: {} | dataset: {} \n'.format(mode, self.testing_set_name))
			fn.write('\tEvaluation range min: {} | max: {} \n'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))
			fn.close()

		else:
			self._load_models(eval_model_list, mode)

		print('Evaluating with mode: {} | dataset: {}'.format(mode, self.testing_set_name))
		print('\tEvaluation range min: {} | max: {}'.format(self.EVAL_DEPTH_MIN, self.EVAL_DEPTH_MAX))

		total_loss = 0.
		count = 0

		predTensor = torch.zeros((1, 1, self.H, self.W)).to('cpu')
		grndTensor = torch.zeros((1, 1, self.H, self.W)).to('cpu')
		imgTensor = torch.zeros((1, 3, self.H, self.W)).to('cpu')
		extTensor = torch.zeros((1, 6)).to('cpu')
		idx = 0

		# tensorboardX_iter_count = 0
		with torch.no_grad():
			for sample_dict in self.testing_dataloader:
				imageTensor, depthGTTensor = sample_dict['rgb'], sample_dict['depth']
				extrinsic_para = sample_dict['extrinsic'].float() # cast to float to avoid a double/float dtype mismatch

				if "intrinsic" in sample_dict.keys():
					# for ScanNet only
					intrinsic_para = sample_dict['intrinsic'].float() # fx, fy, px, py
					focal_length = intrinsic_para[:, :2]
					p_pt = intrinsic_para[:, 2:]
				else:
					# for interiorNet
					focal_length = 300
					p_pt = (120, 160)

				extrinsic_channel = get_extrinsic_channel(imageTensor, focal_length, p_pt, extrinsic_para, self.CEILING_HEIGHT)
				imageTensor_C = torch.cat((imageTensor, extrinsic_channel), dim=1)
				valid_mask = np.logical_and(depthGTTensor >= self.EVAL_DEPTH_MIN, depthGTTensor <= self.EVAL_DEPTH_MAX)

				idx += imageTensor.shape[0]
				print('epoch {}: processed {} samples in the {} set'.format(mode, str(idx), set_name))
				imageTensor_C = imageTensor_C.to(self.device)
				depthGTTensor = depthGTTensor.to(self.device)	# real depth

				if self.is_train and self.use_apex:
					with amp.disable_casts():
						predDepth = self.depthEstModel(imageTensor_C)[-1].detach().to('cpu')
				else:
					predDepth = self.depthEstModel(imageTensor_C)[-1].detach().to('cpu')

				# recover real depth
				predDepth = ((predDepth + 1.0) * 0.5 * (self.MAX_DEPTH_CLIP - self.MIN_DEPTH_CLIP)) + self.MIN_DEPTH_CLIP

				depthGTTensor = depthGTTensor.detach().to('cpu')
				predTensor = torch.cat((predTensor, predDepth), dim=0)
				grndTensor = torch.cat((grndTensor, depthGTTensor), dim=0)
				imgTensor = torch.cat((imgTensor, imageTensor.to('cpu')), dim=0)
				extTensor = torch.cat((extTensor, extrinsic_para), dim=0)

				if isinstance(mode, int) and self.is_train:
					eval_depth_loss = self.L1Loss(predDepth[valid_mask], depthGTTensor[valid_mask])
					total_loss += eval_depth_loss.detach().cpu()

				count += 1

			if isinstance(mode, int) and self.is_train:
				validation_loss = (total_loss / count)

			results_nyu = Result(mask_min=self.EVAL_DEPTH_MIN, mask_max=self.EVAL_DEPTH_MAX)
			results_nyu.evaluate(predTensor[1:], grndTensor[1:])
			individual_results = results_nyu.individual_results(predTensor[1:], grndTensor[1:])

			result1 = '\tabs_rel:{:.3f}, sq_rel:{:.3f}, rmse:{:.3f}, rmse_log:{:.3f}, mae:{:.3f} '.format(
					results_nyu.absrel, results_nyu.sqrel, results_nyu.rmse, results_nyu.rmselog, results_nyu.mae)
			result2 = '\t[<1.25]:{:.3f}, [<1.25^2]:{:.3f}, [<1.25^3]:{:.3f}'.format(results_nyu.delta1, results_nyu.delta2, results_nyu.delta3)

			print(result1)
			print(result2)

			if isinstance(mode, int) and self.is_train:
				self.EVAL_all_results[str(mode)] = result1 + '\t' + result2

				if validation_loss.item() < self.EVAL_best_loss:
					self.EVAL_best_loss = validation_loss.item()
					self.EVAL_best_model_epoch = mode
					self.save_models(self.model_name, mode='best')

				best_model_summary = '\tCurrent eval loss {:.4f}, current best loss {:.4f}, current best model {}\n'.format(validation_loss.item(), self.EVAL_best_loss, self.EVAL_best_model_epoch)
				print(best_model_summary)

				fn = open(self.evaluate_log, 'a')
				fn.write(result1 + '\n')
				fn.write(result2 + '\n')
				fn.write(best_model_summary + '\n')
				fn.close()

			return_dict = {}
			return_dict['rgb'] = imgTensor[1:]
			return_dict['depth_pred'] = predTensor[1:]
			return_dict['depth_gt'] = grndTensor[1:]
			return_dict['extrinsic'] = extTensor[1:]
			return_dict['ind_results'] = individual_results
			
			return return_dict
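Both depth examples map network output from [-1, 1] back to metric depth; example #11 uses the general form with an explicit clip range. The forward and inverse mappings, for reference (a sketch, not source code):

def denorm_depth(x, d_min, d_max):
    # [-1, 1] -> [d_min, d_max]
    return (x + 1.0) * 0.5 * (d_max - d_min) + d_min

def norm_depth(d, d_min, d_max):
    # [d_min, d_max] -> [-1, 1]
    return (d - d_min) / (d_max - d_min) * 2.0 - 1.0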
Example #12
def train(cfg, local_rank, distributed, logger):
    if is_main_process():
        wandb.init(project='scene-graph',
                   entity='sgg-speaker-listener',
                   config=cfg.LISTENER)
    debug_print(logger, 'prepare training')

    model = build_detection_model(cfg)
    listener = build_listener(cfg)
    if is_main_process():
        wandb.watch(listener)

    debug_print(logger, 'end model construction')

    # modules that should be always set in eval mode
    # their eval() method should be called after model.train() is called
    eval_modules = (
        model.rpn,
        model.backbone,
        model.roi_heads.box,
    )

    fix_eval_modules(eval_modules)

    # NOTE, we slow down the LR of the layers start with the names in slow_heads
    if cfg.MODEL.ROI_RELATION_HEAD.PREDICTOR == "IMPPredictor":
        slow_heads = [
            "roi_heads.relation.box_feature_extractor",
            "roi_heads.relation.union_feature_extractor.feature_extractor",
        ]
    else:
        slow_heads = []

    # load pretrain layers to new layers
    load_mapping = {
        "roi_heads.relation.box_feature_extractor":
        "roi_heads.box.feature_extractor",
        "roi_heads.relation.union_feature_extractor.feature_extractor":
        "roi_heads.box.feature_extractor"
    }

    if cfg.MODEL.ATTRIBUTE_ON:
        load_mapping[
            "roi_heads.relation.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"
        load_mapping[
            "roi_heads.relation.union_feature_extractor.att_feature_extractor"] = "roi_heads.attribute.feature_extractor"

    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)
    listener.to(device)

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    num_batch = cfg.SOLVER.IMS_PER_BATCH
    optimizer = make_optimizer(cfg,
                               model,
                               logger,
                               slow_heads=slow_heads,
                               slow_ratio=10.0,
                               rl_factor=float(num_batch))
    listener_optimizer = make_listener_optimizer(cfg, listener)
    scheduler = make_lr_scheduler(cfg, optimizer, logger)
    listener_scheduler = None
    debug_print(logger, 'end optimizer and schedule')
    # Initialize mixed-precision training
    use_mixed_precision = cfg.DTYPE == "float16"
    amp_opt_level = 'O1' if use_mixed_precision else 'O0'
    #listener, listener_optimizer = amp.initialize(listener, listener_optimizer, opt_level='O0')
    # initialize model and listener in a single call; running amp.initialize()
    # twice on the same model is unsupported
    [model, listener], [optimizer, listener_optimizer
                        ] = amp.initialize([model, listener],
                                           [optimizer, listener_optimizer],
                                           opt_level=amp_opt_level,
                                           loss_scale=1)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

        listener = torch.nn.parallel.DistributedDataParallel(
            listener,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            find_unused_parameters=True,
        )

    debug_print(logger, 'end distributed')
    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    listener_dir = cfg.LISTENER_DIR
    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg,
                                         model,
                                         optimizer,
                                         scheduler,
                                         output_dir,
                                         save_to_disk,
                                         custom_scheduler=True)

    listener_checkpointer = Checkpointer(listener,
                                         optimizer=listener_optimizer,
                                         save_dir=listener_dir,
                                         save_to_disk=save_to_disk,
                                         custom_scheduler=False)

    if checkpointer.has_checkpoint():
        extra_checkpoint_data = checkpointer.load(
            cfg.MODEL.PRETRAINED_DETECTOR_CKPT,
            update_schedule=cfg.SOLVER.UPDATE_SCHEDULE_DURING_LOAD)
        arguments.update(extra_checkpoint_data)
    else:
        # load_mapping is only used when we init current model from detection model.
        checkpointer.load(cfg.MODEL.PRETRAINED_DETECTOR_CKPT,
                          with_optim=False,
                          load_mapping=load_mapping)

    # if there is certain checkpoint in output_dir, load it, else load pretrained detector
    if listener_checkpointer.has_checkpoint():
        extra_listener_checkpoint_data = listener_checkpointer.load()
        amp.load_state_dict(extra_listener_checkpoint_data['amp'])
        '''
        print('Weights after load: ')
        print('****************************')
        print(listener.gnn.conv1.node_model.node_mlp_1[0].weight)
        print('****************************')
        '''
        # arguments.update(extra_listener_checkpoint_data)
    debug_print(logger, 'end load checkpointer')
    train_data_loader = make_data_loader(cfg,
                                         mode='train',
                                         is_distributed=distributed,
                                         start_iter=arguments["iteration"],
                                         ret_images=True)
    val_data_loaders = make_data_loader(cfg,
                                        mode='val',
                                        is_distributed=distributed,
                                        ret_images=True)

    debug_print(logger, 'end dataloader')
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    if cfg.SOLVER.PRE_VAL:
        logger.info("Validate before training")
        #output =  run_val(cfg, model, listener, val_data_loaders, distributed, logger)
        #print('OUTPUT: ', output)
        #(sg_loss, img_loss, sg_acc, img_acc) = output

    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_data_loader)
    start_iter = arguments["iteration"]
    start_training_time = time.time()
    end = time.time()

    print_first_grad = True

    listener_loss_func = torch.nn.MarginRankingLoss(margin=1, reduction='none')
    mistake_saver = None
    if is_main_process():
        ds_catalog = DatasetCatalog()
        dict_file_path = os.path.join(
            ds_catalog.DATA_DIR,
            ds_catalog.DATASETS['VG_stanford_filtered_with_attribute']
            ['dict_file'])
        ind_to_classes, ind_to_predicates = load_vg_info(dict_file_path)
        ind_to_classes = {k: v for k, v in enumerate(ind_to_classes)}
        ind_to_predicates = {k: v for k, v in enumerate(ind_to_predicates)}
        print('ind to classes:', ind_to_classes, '\n ind to predicates:',
              ind_to_predicates)
        mistake_saver = MistakeSaver(
            '/Scene-Graph-Benchmark.pytorch/filenames_masked', ind_to_classes,
            ind_to_predicates)

    #is_printed = False
    while True:
        try:
            listener_iteration = 0
            for iteration, (images, targets,
                            image_ids) in enumerate(train_data_loader,
                                                    start_iter):
                listener_optimizer.zero_grad()

                #print(f'ITERATION NUMBER: {iteration}')
                if any(len(target) < 1 for target in targets):
                    logger.error(
                        f"Iteration={iteration + 1} || Image Ids used for training {image_ids} || targets Length={[len(target) for target in targets]}"
                    )
                if len(images) <= 1:
                    continue

                data_time = time.time() - end
                iteration = iteration + 1
                listener_iteration += 1
                arguments["iteration"] = iteration
                model.train()
                fix_eval_modules(eval_modules)
                images_list = deepcopy(images)
                images_list = to_image_list(
                    images_list, cfg.DATALOADER.SIZE_DIVISIBILITY).to(device)

                #SAVE IMAGE TO PC
                '''
                transform = transforms.Compose([
                    transforms.ToPILImage(),
                    #transforms.Resize((cfg.LISTENER.IMAGE_SIZE, cfg.LISTENER.IMAGE_SIZE)),
                    transforms.ToTensor(),
                ])
                '''
                # turn images to a uniform size
                #print('IMAGE BEFORE Transform: ', images[0], 'GPU: ', get_rank())
                '''

                if is_main_process():
                    if not is_printed:
                        transform = transforms.ToPILImage()
                        print('SAVING IMAGE')
                        img = transform(images[0].cpu())
                        print('DONE TRANSFORM')
                        img.save('image.png')
                        print('DONE SAVING IMAGE')
                        print('ids ', image_ids[0])

                '''

                for i in range(len(images)):
                    images[i] = images[i].unsqueeze(0)
                    images[i] = F.interpolate(images[i],
                                              size=(224, 224),
                                              mode='bilinear',
                                              align_corners=False)
                    images[i] = images[i].squeeze()

                images = torch.stack(images).to(device)
                #images.requires_grad_()

                targets = [target.to(device) for target in targets]

                #print('IMAGE BEFORE Model: ', images[0], 'GPU: ', get_rank())
                _, sgs = model(images_list, targets)
                #print('IMAGE AFTER Model: ', images)
                '''
                is_printed = False
                if is_main_process():
                    if not is_printed:
                        print('PRINTING OBJECTS')
                        (obj, rel_pair, rel) = sgs[0]
                        obj = torch.argmax(obj, dim=1)
                        for i in range(obj.size(0)):
                            print(f'OBJECT {i}: ', obj[i])
                        print('DONE PRINTING OBJECTS')
                        is_printed=True

                '''
                image_list = None
                sgs = collate_sgs(sgs, cfg.MODEL.DEVICE)
                ''' 

                if is_main_process():
                    if not is_printed:
                        mistake_saver.add_mistake((image_ids[0], image_ids[1]), (sgs[0], sgs[1]), 231231, 'SG') 
                        mistake_saver.toHtml('/www')
                        is_printed = True
                
                '''

                listener_loss = 0
                gap_reward = 0
                avg_acc = 0
                num_correct = 0
                score_matrix = torch.zeros((images.size(0), images.size(0)))
                # fill score matrix
                for true_index, sg in enumerate(sgs):
                    acc = 0
                    detached_sg = (sg[0].detach().requires_grad_().to(
                        torch.float32), sg[1].long(),
                                   sg[2].detach().requires_grad_().to(
                                       torch.float32))
                    #scores = listener(sg, images)
                    with amp.disable_casts():
                        scores = listener(detached_sg, images)
                    score_matrix[true_index] = scores

                #print('Score matrix:', score_matrix)
                score_matrix = score_matrix.to(device)
                # fill loss matrix
                loss_matrix = torch.zeros((2, images.size(0), images.size(0)),
                                          device=device)
                # sg centered scores
                for true_index in range(loss_matrix.size(1)):
                    row_score = score_matrix[true_index]
                    (true_scores, predicted_scores,
                     binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores,
                                                  predicted_scores, binary)
                    loss_matrix[0][true_index] = loss_vec
                # image centered scores
                transposted_score_matrix = score_matrix.t()
                for true_index in range(loss_matrix.size(1)):
                    row_score = transposted_score_matrix[true_index]
                    (true_scores, predicted_scores,
                     binary) = format_scores(row_score, true_index, device)
                    loss_vec = listener_loss_func(true_scores,
                                                  predicted_scores, binary)
                    loss_matrix[1][true_index] = loss_vec

                print('iteration:', listener_iteration)
                sg_acc = 0
                img_acc = 0
                # calculate accuracy
                for i in range(loss_matrix.size(1)):
                    temp_sg_acc = 0
                    temp_img_acc = 0
                    for j in range(loss_matrix.size(2)):
                        if loss_matrix[0][i][i] > loss_matrix[0][i][j]:
                            temp_sg_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process(
                                ) and listener_iteration >= 600 and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(),
                                                     sgs[i][1],
                                                     sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(),
                                                     sgs[j][1],
                                                     sgs[j][2].detach())
                                    mistake_saver.add_mistake(
                                        (image_ids[i], image_ids[j]),
                                        (detached_sg_i, detached_sg_j),
                                        listener_iteration, 'SG')
                        if loss_matrix[1][i][i] > loss_matrix[1][j][i]:
                            temp_img_acc += 1
                        else:
                            if cfg.LISTENER.HTML:
                                if is_main_process(
                                ) and listener_iteration >= 600 and listener_iteration % 25 == 0 and i != j:
                                    detached_sg_i = (sgs[i][0].detach(),
                                                     sgs[i][1],
                                                     sgs[i][2].detach())
                                    detached_sg_j = (sgs[j][0].detach(),
                                                     sgs[j][1],
                                                     sgs[j][2].detach())
                                    mistake_saver.add_mistake(
                                        (image_ids[i], image_ids[j]),
                                        (detached_sg_i, detached_sg_j),
                                        listener_iteration, 'IMG')

                    temp_sg_acc = temp_sg_acc * 100 / (loss_matrix.size(1) - 1)
                    temp_img_acc = temp_img_acc * 100 / (loss_matrix.size(1) -
                                                         1)
                    sg_acc += temp_sg_acc
                    img_acc += temp_img_acc

                if cfg.LISTENER.HTML:
                    if is_main_process(
                    ) and listener_iteration % 100 == 0 and listener_iteration >= 600:
                        mistake_saver.toHtml('/www')

                sg_acc /= loss_matrix.size(1)
                img_acc /= loss_matrix.size(1)

                avg_sg_acc = torch.tensor([sg_acc]).to(device)
                avg_img_acc = torch.tensor([img_acc]).to(device)
                # reduce acc over all gpus
                avg_acc = {'sg_acc': avg_sg_acc, 'img_acc': avg_img_acc}
                avg_acc_reduced = reduce_loss_dict(avg_acc)

                sg_acc = sum(acc for acc in avg_acc_reduced['sg_acc'])
                img_acc = sum(acc for acc in avg_acc_reduced['img_acc'])

                # log acc to wandb
                if is_main_process():
                    wandb.log({
                        "Train SG Accuracy": sg_acc.item(),
                        "Train IMG Accuracy": img_acc.item()
                    })

                sg_loss = 0
                img_loss = 0

                for i in range(loss_matrix.size(0)):
                    for j in range(loss_matrix.size(1)):
                        loss_matrix[i][j][j] = 0.

                for i in range(loss_matrix.size(1)):
                    sg_loss += torch.max(loss_matrix[0][i])
                    img_loss += torch.max(loss_matrix[1][:, i])  # column i: image-centered losses

                sg_loss = sg_loss / loss_matrix.size(1)
                img_loss = img_loss / loss_matrix.size(1)
                sg_loss = sg_loss.to(device)
                img_loss = img_loss.to(device)

                loss_dict = {'sg_loss': sg_loss, 'img_loss': img_loss}

                losses = sum(loss for loss in loss_dict.values())

                # reduce losses over all GPUs for logging purposes
                loss_dict_reduced = reduce_loss_dict(loss_dict)
                sg_loss_reduced = loss_dict_reduced['sg_loss']
                img_loss_reduced = loss_dict_reduced['img_loss']
                if is_main_process():
                    wandb.log({"Train SG Loss": sg_loss_reduced})
                    wandb.log({"Train IMG Loss": img_loss_reduced})

                losses_reduced = sum(loss
                                     for loss in loss_dict_reduced.values())
                meters.update(loss=losses_reduced, **loss_dict_reduced)

                # Note: If mixed precision is not used, this ends up doing nothing
                # Otherwise apply loss scaling for mixed-precision recipe
                losses.backward()
                #with amp.scale_loss(losses, listener_optimizer) as scaled_losses:
                #    scaled_losses.backward()

                verbose = (iteration % cfg.SOLVER.PRINT_GRAD_FREQ
                           ) == 0 or print_first_grad  # print grad or not
                print_first_grad = False
                #clip_grad_value([(n, p) for n, p in listener.named_parameters() if p.requires_grad], cfg.LISTENER.CLIP_VALUE, logger=logger, verbose=True, clip=True)
                listener_optimizer.step()

                batch_time = time.time() - end
                end = time.time()
                meters.update(time=batch_time, data=data_time)

                eta_seconds = meters.time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

                if iteration % 200 == 0 or iteration == max_iter:
                    logger.info(
                        meters.delimiter.join([
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]).format(
                            eta=eta_string,
                            iter=iteration,
                            meters=str(meters),
                            lr=listener_optimizer.param_groups[-1]["lr"],
                            memory=torch.cuda.max_memory_allocated() / 1024.0 /
                            1024.0,
                        ))

                if iteration % checkpoint_period == 0:
                    """
                    print('Model before save')
                    print('****************************')
                    print(listener.gnn.conv1.node_model.node_mlp_1[0].weight)
                    print('****************************')
                    """
                    listener_checkpointer.save(
                        "model_{:07d}".format(listener_iteration),
                        amp=amp.state_dict())
                    #listener_checkpointer.save("model_{:07d}".format(listener_iteration))

                if iteration == max_iter:
                    listener_checkpointer.save("model_final",
                                               amp=amp.state_dict())
                    #listener_checkpointer.save("model_final")

                val_result = None  # used for scheduler updating
                if cfg.SOLVER.TO_VAL and iteration % cfg.SOLVER.VAL_PERIOD == 0:
                    logger.info("Start validating")
                    val_result = run_val(cfg, model, listener,
                                         val_data_loaders, distributed, logger)
                    (sg_loss, img_loss, sg_acc, img_acc,
                     speaker_val) = val_result

                    if is_main_process():
                        wandb.log({
                            "Validation SG Accuracy": sg_acc,
                            "Validation IMG Accuracy": img_acc,
                            "Validation SG Loss": sg_loss,
                            "Validation IMG Loss": img_loss,
                            "Speaker Val": speaker_val,
                        })

        except Exception as err:
            raise err  # note: re-raising here makes the recovery code below unreachable
            print('Dataset finished, creating new')
            train_data_loader = make_data_loader(
                cfg,
                mode='train',
                is_distributed=distributed,
                start_iter=arguments["iteration"],
                ret_images=True)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (max_iter)))
    return listener
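Example #12 also shows the checkpointing side of apex: the loss-scaler state is saved with amp.state_dict() and restored with amp.load_state_dict(), so a resumed run keeps its loss scale. A minimal sketch of that round trip; model and optimizer are assumed to be amp-initialized:

import torch
from apex import amp

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict(),          # loss-scaler state
}
torch.save(checkpoint, 'checkpoint.pt')

state = torch.load('checkpoint.pt')
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
amp.load_state_dict(state['amp'])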