예제 #1
0
파일: refcoco.py 프로젝트: j-min/VL-T5
    def train(self):
        """Train the model, validate after every epoch, and test the best checkpoint.

        Flow, as implemented below:

        1. When ``self.verbose`` is set, print the "oracle" accuracy of the
           validation and test loaders (share of samples whose
           ``exists_target`` entry is set) and initialise Weights & Biases.
        2. For each of ``self.args.epochs`` epochs, run one optimisation step
           per batch of ``self.train_loader`` (native AMP, apex AMP or plain
           precision depending on ``self.args.fp16`` and the module-level
           ``_use_native_amp`` / ``_use_apex`` flags).
        3. After each epoch, on the verbose process only, evaluate on
           ``self.val_loader``, save a ``"BEST"`` checkpoint when validation
           accuracy improves, and log to wandb and stdout.
        4. After the last epoch, on the verbose process only, save ``"LAST"``,
           reload ``"BEST"`` from ``self.args.output`` and evaluate on
           ``self.test_loader``.

        Returns:
            None. Results are reported through wandb, stdout and checkpoints.
        """
        if self.verbose:
            # Oracle accuracy on the validation set: fraction of samples whose
            # `exists_target` entry is set.
            # NOTE(review): presumably `exists_target` marks whether the
            # ground-truth region is among the candidate boxes -- confirm
            # against the dataset code.
            n_correct = 0
            n_total = 0
            for batch in self.val_loader:
                exists_target = batch['exists_target']

                n_correct += exists_target.sum().item()
                n_total += len(exists_target)

            print(f'Val Oracle acc: {n_correct / n_total * 100:.2f}%')

            # Same oracle statistic for the test set.
            n_correct = 0
            n_total = 0
            for batch in self.test_loader:
                exists_target = batch['exists_target']

                n_correct += exists_target.sum().item()
                n_total += len(exists_target)

            print(f'Test Oracle acc: {n_correct / n_total * 100:.2f}%')

        if self.verbose:
            loss_meter = LossMeter()

            # Best validation accuracy seen so far and the epoch it came from.
            best_valid_acc = 0.
            best_epoch = 0

            # The wandb project name depends on the backbone family.
            # NOTE(review): `project_name` stays unbound when the backbone
            # string contains neither 't5' nor 'bart'.
            if 't5' in self.args.backbone:
                project_name = "VLT5_RefCOCOg"
            elif 'bart' in self.args.backbone:
                project_name = "VLBart_RefCOCOg"

            if self.args.RefCOCO_GT:
                project_name += '_GT'

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Upload every *.py file that sits next to this source file,
            # with paths made relative to the parent directory.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        # Let the verbose process finish its setup before anyone starts training.
        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # When resuming, shift the epoch index by the starting epoch.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=120)

            # Running sums of the scalar entries of `results` that are tracked
            # per epoch (only 'loss' here).
            epoch_results = {
                'loss': 0,
            }

            # Training accuracy counters, used only with log_train_accuracy.
            n_correct = 0
            n_total = 0

            for step_i, batch in enumerate(self.train_loader):

                # Pass the flag to train_step through the batch dict.
                batch['log_train_accuracy'] = self.args.log_train_accuracy

                # Forward pass; in distributed mode train_step is called on
                # the wrapped module (`self.model.module`).
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Backward pass, with loss scaling when mixed precision is on.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # Only the value is needed from here on.
                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        # Gradients must be unscaled before their norm is clipped.
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Reset gradients by dropping the tensors instead of zeroing them.
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Current learning rate, for display only.
                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        # Optimizer has no get_lr(): fall back to the configured value.
                        lr = self.args.lr

                if self.args.log_train_accuracy:
                    # `correct` is a per-sample collection; its sum is the
                    # number of hits in this batch.
                    correct = results['correct']
                    n_correct += sum(correct)
                    n_total += len(correct)

                if self.verbose:
                    loss_meter.update(loss.item())
                    # acc_meter.update(results['acc'].item())

                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    desc_str += f'Loss {loss_meter.val:.3f} |'
                    # desc_str += f' Acc {acc_meter.val:.3f} |'

                    if self.args.log_train_accuracy:
                        desc_str += f' Correct {n_correct:.0f}'
                        desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                # Synchronise all processes after every step.
                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()

            if self.args.log_train_accuracy:
                train_score_dict = {'n_correct': n_correct, 'n_total': n_total}
                # NOTE(review): reduce_dict is defined elsewhere; presumably it
                # aggregates the counters across processes -- confirm.
                train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

            # Validation
            # valid_score_dict = self.evaluate(self.val_loader)
            # valid_score_dict = reduce_dict(valid_score_dict, self.args.gpu)

            # Validation and logging run only on the verbose process.
            if self.verbose:
                if self.args.log_train_accuracy:
                    train_acc = train_score_dict[
                        'n_correct'] / train_score_dict['n_total'] * 100
                    train_n_correct = int(train_score_dict['n_correct'])
                    train_n_total = int(train_score_dict['n_total'])

                # Validation
                valid_score_dict = self.evaluate(self.val_loader)
                valid_acc = valid_score_dict['n_correct'] / valid_score_dict[
                    'n_total'] * 100
                valid_n_correct = int(valid_score_dict['n_correct'])
                valid_n_total = int(valid_score_dict['n_total'])

                # Keep a checkpoint of the best model by validation accuracy.
                if valid_acc > best_valid_acc:
                    best_valid_acc = valid_acc
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''

                if self.args.log_train_accuracy:
                    log_str += f"\nEpoch {epoch}: Train"
                    log_str += f" Acc {train_acc:.2f}% |"
                    log_str += f" # correct {train_n_correct} # total {train_n_total}"

                log_str += f"\nEpoch {epoch}: Valid"
                log_str += f" Acc {valid_acc:.2f}% |"
                log_str += f" # correct {valid_n_correct} # total {valid_n_total}"

                log_str += f"\nEpoch {best_epoch}: Best  Acc {best_valid_acc:.2f}%\n"

                wandb_log_dict = {}

                if self.args.log_train_accuracy:
                    wandb_log_dict['Train/Acc'] = train_acc

                wandb_log_dict['Valid/Acc'] = valid_acc

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            # Other processes wait here while the verbose one validates.
            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

            # Test Set
            # Evaluate the best validation checkpoint, not the last one.
            # NOTE(review): this assumes a "BEST" checkpoint was written, i.e.
            # validation accuracy exceeded 0 at least once.
            best_path = os.path.join(self.args.output, 'BEST')
            self.load(best_path)

            test_score_dict = self.evaluate(self.test_loader)
            test_acc = test_score_dict['n_correct'] / test_score_dict[
                'n_total'] * 100
            test_n_correct = int(test_score_dict['n_correct'])
            test_n_total = int(test_score_dict['n_total'])

            wandb_log_dict = {}
            wandb_log_dict['Test/Acc'] = test_acc
            # `epoch` is the value left over from the last loop iteration.
            wandb.log(wandb_log_dict, step=epoch)

            log_str = ''
            log_str += f"\nTest Acc {test_acc:.2f}%"
            log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}"

            print(log_str)

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            dist.barrier()
예제 #2
0
파일: vcr.py 프로젝트: j-min/VL-T5
    def train(self):
        """Train the VCR model with optional gradient accumulation.

        Flow, as implemented below:

        1. When ``self.verbose`` is set, create the loss meters and initialise
           Weights & Biases.
        2. For each of ``self.args.epochs`` epochs, iterate over
           ``self.train_loader``. Every micro-batch is forwarded and
           back-propagated; the optimizer (and LR scheduler) step only once
           every ``self.args.gradient_accumulation_steps`` micro-batches, or
           on the last batch of the epoch.
        3. After each epoch, run ``self.evaluate_val`` on ``self.val_loader``;
           on the verbose process, save ``"BEST"`` when the validation Q->AR
           score improves and log the metrics.
        4. After the last epoch, save ``"LAST"`` (verbose process), reload
           ``"BEST"`` on every process, and on the verbose process dump test
           predictions to ``test_submit.csv`` under ``self.args.output``.

        Gradient clipping (and, with native AMP, ``scaler.unscale_``) is
        applied only on micro-batches that end in an optimizer step.
        ``GradScaler.unscale_`` may be called only once per optimizer between
        ``step()``/``update()`` calls, and the norm should be measured on the
        fully accumulated gradient rather than on partial sums.

        Returns:
            None. In distributed mode the process calls ``exit()`` at the end.
        """
        if self.verbose:
            loss_meter = LossMeter()
            qa_loss_meter = LossMeter()
            qar_loss_meter = LossMeter()

            # Best validation Q->AR score so far and the epoch it came from.
            best_valid_Q_AR = 0.
            best_epoch = 0

            # The wandb project name depends on backbone family and on
            # whether visual input is used.
            if 't5' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLT5_VCR"
                else:
                    project_name = "T5_VCR"
            elif 'bart' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLBart_VCR"
                else:
                    project_name = "Bart_VCR"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Upload every *.py file that sits next to this source file.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # When resuming, shift the epoch index by the starting epoch.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=200)

            epoch_results = {
                'loss': 0,
            }

            # Training accuracy counters, used only with log_train_accuracy.
            Q_A_results = 0
            QA_R_results = 0
            Q_AR_results = 0
            n_total = 0

            # Losses summed over the micro-batches of the current
            # accumulation window; reset after every optimizer step.
            n_accu = 0
            train_loss = 0
            train_qa_loss = 0
            train_qar_loss = 0

            for step_i, batch in enumerate(self.train_loader):

                # Pass the flag to train_step through the batch dict.
                batch['log_train_accuracy'] = self.args.log_train_accuracy

                # Forward pass; in distributed mode train_step is called on
                # the wrapped module (`self.model.module`).
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Scale so that the accumulated gradient matches one large batch.
                # The division is in place, so results['loss'] is scaled as well.
                if self.args.gradient_accumulation_steps > 1:
                    loss /= self.args.gradient_accumulation_steps

                # Backward pass, with loss scaling when mixed precision is on.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # An optimizer step happens when the accumulation window is
                # full or on the final batch of the epoch.
                update = True
                if self.args.gradient_accumulation_steps > 1:
                    update = ((step_i + 1) %
                              self.args.gradient_accumulation_steps
                              == 0) or (step_i == len(self.train_loader) - 1)
                n_accu += 1

                if update:
                    # Update Parameters
                    # Clip once per optimizer step, on the fully accumulated
                    # gradient. Calling scaler.unscale_ on every micro-batch
                    # would fail, since it may run only once between
                    # scaler.step()/update() calls.
                    if self.args.clip_grad_norm > 0:
                        if self.args.fp16 and _use_native_amp:
                            self.scaler.unscale_(self.optim)
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(),
                                self.args.clip_grad_norm)
                        elif self.args.fp16 and _use_apex:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optim),
                                self.args.clip_grad_norm)
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(),
                                self.args.clip_grad_norm)

                    if self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optim)
                        self.scaler.update()
                    else:
                        self.optim.step()

                    if self.lr_scheduler:
                        self.lr_scheduler.step()
                    # Reset gradients by dropping the tensors instead of zeroing them.
                    for param in self.model.parameters():
                        param.grad = None

                    global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Current learning rate, for display only.
                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        # Optimizer has no get_lr(): fall back to the configured value.
                        lr = self.args.lr

                if self.verbose:
                    # `loss` is already divided by the accumulation steps;
                    # the QA/QAR losses are divided here to match.
                    train_loss += loss.detach().item()
                    train_qa_loss += results['qa_loss'].item(
                    ) / self.args.gradient_accumulation_steps
                    train_qar_loss += results['qar_loss'].item(
                    ) / self.args.gradient_accumulation_steps

                    if self.args.log_train_accuracy:
                        qa_pred = results['qa_pred']
                        qar_pred = results['qar_pred']

                        a_labels = batch['answer_labels']
                        r_labels = batch['rationale_labels']

                        # Q->AR counts a sample only when both the answer and
                        # the rationale are predicted correctly.
                        Q_A_correct = a_labels == qa_pred
                        QA_R_correct = r_labels == qar_pred
                        Q_AR_correct = Q_A_correct & QA_R_correct

                        Q_A_results += sum(Q_A_correct)
                        QA_R_results += sum(QA_R_correct)
                        Q_AR_results += sum(Q_AR_correct)
                        n_total += len(qa_pred)

                    if update:
                        # Rescale for windows shorter than the configured
                        # accumulation length (e.g. at the end of the epoch).
                        if self.args.gradient_accumulation_steps > 1:
                            train_loss *= self.args.gradient_accumulation_steps / n_accu
                            train_qa_loss *= self.args.gradient_accumulation_steps / n_accu
                            train_qar_loss *= self.args.gradient_accumulation_steps / n_accu

                        loss_meter.update(train_loss)
                        qa_loss_meter.update(train_qa_loss)
                        qar_loss_meter.update(train_qar_loss)
                        desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step} |'
                        desc_str += f' Loss {loss_meter.val:.3f} |'
                        desc_str += f' QA Loss {qa_loss_meter.val:.3f} |'
                        desc_str += f' QAR Loss {qar_loss_meter.val:.3f} |'

                        # Start a fresh accumulation window.
                        train_loss = 0
                        train_qa_loss = 0
                        train_qar_loss = 0
                        n_accu = 0

                        if self.args.log_train_accuracy:
                            desc_str += f' Q -> A {Q_A_results} ({Q_A_results/n_total*100:.1f}%)'
                            desc_str += f' QA -> R {QA_R_results} ({QA_R_results/n_total*100:.1f}%)'
                            desc_str += f' Q -> AR {Q_AR_results} ({Q_AR_results/n_total*100:.1f}%)'

                        pbar.set_description(desc_str)
                    pbar.update(1)

            if self.verbose:
                pbar.close()

            if self.args.log_train_accuracy:
                train_score_dict = {
                    'Q_A': Q_A_results,
                    'QA_R': QA_R_results,
                    'Q_AR': Q_AR_results,
                    'n_total': n_total
                }
                # NOTE(review): reduce_dict is defined elsewhere; presumably it
                # aggregates the counters across processes -- confirm.
                train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

            # Validation
            # Called on every process; the scores are consumed below on the
            # verbose process only.
            valid_score_dict = self.evaluate_val(self.val_loader)

            if self.verbose:
                if self.args.log_train_accuracy:
                    train_Q_A = train_score_dict['Q_A'] / train_score_dict[
                        'n_total'] * 100
                    train_QA_R = train_score_dict['QA_R'] / train_score_dict[
                        'n_total'] * 100
                    train_Q_AR = train_score_dict['Q_AR'] / train_score_dict[
                        'n_total'] * 100
                    train_n_total = int(train_score_dict['n_total'])

                valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[
                    'n_total'] * 100
                valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[
                    'n_total'] * 100
                valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[
                    'n_total'] * 100
                valid_n_total = int(valid_score_dict['n_total'])

                # Keep a checkpoint of the best model by validation Q->AR.
                if valid_Q_AR > best_valid_Q_AR:
                    best_valid_Q_AR = valid_Q_AR
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''

                if self.args.log_train_accuracy:
                    log_str += f"\nEpoch {epoch}: Train |"
                    log_str += f" # examples: {train_n_total} |"
                    log_str += f" Q -> A {train_Q_A:.2f}%"
                    log_str += f" QA -> R {train_QA_R:.2f}%"
                    log_str += f" Q -> AR {train_Q_AR:.2f}%"

                log_str += f"\nEpoch {epoch}: Valid |"
                log_str += f" # examples: {valid_n_total} |"
                log_str += f" Q -> A {valid_Q_A:.2f}%"
                log_str += f" QA -> R {valid_QA_R:.2f}%"
                log_str += f" Q -> AR {valid_Q_AR:.2f}%"

                log_str += f"\nBest Epoch {best_epoch}: Q -> AR {best_valid_Q_AR:.2f}%\n"

                wandb_log_dict = {}

                if self.args.log_train_accuracy:
                    wandb_log_dict['Train/Q_A'] = train_Q_A
                    wandb_log_dict['Train/QA_R'] = train_QA_R
                    wandb_log_dict['Train/Q_AR'] = train_Q_AR

                wandb_log_dict['Valid/Q_A'] = valid_Q_A
                wandb_log_dict['Valid/QA_R'] = valid_QA_R
                wandb_log_dict['Valid/Q_AR'] = valid_Q_AR

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

        # Test Set
        # Every process reloads the best validation checkpoint.
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)

        if self.verbose:
            dump_path = os.path.join(self.args.output, 'test_submit.csv')

            print('Dumping test set results at', dump_path)
            self.evaluate_test(self.test_loader, dump_path=dump_path)
            wandb.save(dump_path, base_path=self.args.output)

            print('Done!')

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            exit()
예제 #3
0
파일: pretrain.py 프로젝트: j-min/VL-T5
    def train(self):
        """Run the pre-training loop with per-epoch validation and checkpointing.

        Flow, as implemented below:

        1. With ``self.args.dry`` set, run one ``evaluate_epoch`` before training.
        2. When ``self.verbose`` is set, create one loss meter per entry of
           ``self.args.LOSSES_NAME`` and initialise Weights & Biases.
        3. For each of ``self.args.epochs`` epochs, take one optimizer step per
           batch of ``self.train_loader`` and accumulate every tracked loss
           together with its ``<name>_count`` entry.
        4. After each epoch, reduce the accumulated training losses, run
           ``evaluate_epoch``, reduce the validation losses and, when ``'qa'``
           is in ``self.args.losses``, compute QA accuracy with the validation
           dataset's evaluator. The verbose process logs the averages and
           saves an ``"EpochNN"`` checkpoint.

        All ``dist.barrier()`` calls are guarded by ``self.args.distributed``
        so the method can also run in a single process, where no process
        group has been initialised.

        Returns:
            None. Results are reported through wandb, stdout and checkpoints.
        """
        LOSSES_NAME = self.args.LOSSES_NAME

        if self.args.dry:
            # The result is not used; it is overwritten once training starts.
            results = self.evaluate_epoch(epoch=0)

        if self.verbose:
            loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]
            # Large initial value so that the first epoch always improves on it.
            best_eval_loss = 9595.

            # The wandb project name depends on the backbone family.
            if 't5' in self.args.backbone:
                project_name = "VLT5_Pretrain"
            elif 'bart' in self.args.backbone:
                project_name = "VLBart_Pretrain"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Upload every *.py file that sits next to this source file.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # When resuming, shift the epoch index by the starting epoch.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)

            # Train
            self.model.train()

            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=250)

            # Per-loss running sum and sample count for this epoch.
            epoch_results = {}
            for loss_name in LOSSES_NAME:
                epoch_results[loss_name] = 0.
                epoch_results[f'{loss_name}_count'] = 0

            for step_i, batch in enumerate(self.train_loader):

                # Forward pass; in distributed mode train_step is called on
                # the wrapped module (`self.model.module`).
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Backward pass, with loss scaling when mixed precision is on.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        # Gradients must be unscaled before their norm is clipped.
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim), self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()

                # Reset gradients by dropping the tensors instead of zeroing them.
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                # Current learning rate, for display only.
                if self.lr_scheduler:
                    if version.parse(torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        # Optimizer has no get_lr(): fall back to the configured value.
                        lr = self.args.lr

                # Accumulate tracked entries; counts arrive as ints, losses as tensors.
                for k, v in results.items():
                    if k in epoch_results:
                        if isinstance(v, int):
                            epoch_results[k] += v
                        elif isinstance(v, torch.Tensor):
                            epoch_results[k] += v.item()

                if self.verbose:
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} |'

                    for i, (loss_name, loss_meter) in enumerate(zip(LOSSES_NAME, loss_meters)):

                        # The meter receives this batch's loss divided by its count.
                        if loss_name in results:
                            loss_meter.update(results[f'{loss_name}'] / results[f'{loss_name}_count'])
                        if len(loss_meter) > 0:
                            loss_count = epoch_results[f'{loss_name}_count']
                            desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

            if self.verbose:
                pbar.close()

            if self.args.distributed:
                dist.barrier()

            # NOTE(review): reduce_dict is defined elsewhere; with
            # average=False it presumably sums values across processes --
            # confirm, including its behaviour in single-process runs.
            results = reduce_dict(epoch_results, average=False)
            if self.verbose:
                train_loss = results['total_loss']
                train_loss_count = results['total_loss_count']

                avg_train_loss = train_loss / train_loss_count
                losses_str = f"Train Loss: {avg_train_loss:.3f}\n"

                # Every key ending in 'loss' has a matching '<key>_count' entry.
                for name, loss in results.items():
                    if name[-4:] == 'loss':
                        loss_count = int(results[name+'_count'])
                        if loss_count > 0:
                            avg_loss = loss/loss_count
                            losses_str += f"{name} ({loss_count}): {avg_loss:.3f} "
                            wandb.log({f'Train Loss/{name}': avg_loss}, step=epoch)

                losses_str += '\n'
                print(losses_str)

            if self.args.distributed:
                dist.barrier()

            # Validation
            valid_results, valid_uid2ans = self.evaluate_epoch(epoch=epoch)

            valid_results = reduce_dict(valid_results, average=False)
            if self.verbose:
                valid_loss = valid_results['total_loss']
                valid_loss_count = valid_results['total_loss_count']

                avg_valid_loss = valid_loss / valid_loss_count
                losses_str = f"Valid Loss: {avg_valid_loss:.3f}\n"

                for name, loss in valid_results.items():
                    if name[-4:] == 'loss':
                        loss_count = int(valid_results[name+'_count'])
                        if loss_count > 0:
                            avg_loss = loss / loss_count
                            losses_str += f"{name} ({loss_count}): {avg_loss:.3f} "
                            wandb.log({f'Valid Loss/{name}': avg_loss}, step=epoch)

                losses_str += '\n'
                print(losses_str)

            if 'qa' in self.args.losses:
                dset2score, dset2cnt, score, cnt = self.val_loader.dataset.evaluator.evaluate(valid_uid2ans)

                # When this process produced no QA answers, substitute
                # placeholder entries so that every process passes the same
                # keys to the reductions below.
                if len(dset2score) == 0:
                    dset2score = {'vqa': 0, 'gqa': 0, 'visual7w': 0}
                    dset2cnt = {'vqa': 1, 'gqa': 1, 'visual7w': 1}
                    cnt = 3
                    score = 0

                dset2score = reduce_dict(dset2score, average=False)
                dset2cnt = reduce_dict(dset2cnt, average=False)
                score_cnt_dict = reduce_dict({'score': score, 'cnt': cnt}, average=False)

                if self.args.gpu == 0:
                    score = score_cnt_dict['score']
                    cnt = score_cnt_dict['cnt']
                    accu = score / cnt
                    dset2accu = {}
                    for dset in dset2cnt:
                        dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
                    accu_str = "Overall QA Acc %0.4f" % (accu)
                    wandb.log({f'Valid QA Acc/Overall': accu}, step=epoch)
                    sorted_keys = sorted(dset2accu.keys())
                    for key in sorted_keys:
                        accu_str += ", %s Acc %0.4f" % (key, dset2accu[key])
                        wandb.log({f'Valid QA Acc/{key}': dset2accu[key]}, step=epoch)
                    print(accu_str)
                    accu_str += '\n\n'

            if self.args.distributed:
                dist.barrier()

            if self.verbose:
                # Save
                # The best validation loss is tracked, but a checkpoint is
                # written every epoch regardless.
                if avg_valid_loss < best_eval_loss:
                    best_eval_loss = avg_valid_loss
                self.save("Epoch%02d" % (epoch + 1))

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            wandb.log({'finished': True})
예제 #4
0
파일: multitask.py 프로젝트: j-min/VL-T5
    def train(self):
        if self.verbose:
            vqa_loss_meter = LossMeter()
            refcoco_loss_meter = LossMeter()
            # best_eval_loss = 9595.
            quesid2ans = {}
            best_vqa_valid = 0.
            best_vqa_epoch = 0

            # gqa
            best_gqa_valid = 0
            best_gqa_epoch = 0

            # nlvr
            best_nlvr_valid = 0
            best_nlvr_epoch = 0

            # vcr
            best_valid_Q_AR = 0
            best_vcr_epoch = 0

            # refcoco
            best_refcoco_valid = 0
            best_refcoco_epoch = 0

            # caption
            best_caption_valid = 0
            best_caption_epoch = 0

            # mmt
            best_mmt_valid = 0
            best_mmt_epoch = 0

            assert 't5' in self.args.backbone
            self.setup_wandb()

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()

            if self.args.distributed:
                self.train_loader.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=250)

            epoch_results = {
                'loss': 0.,
            }

            task_counter = {
                'vqa': 0,
                'gqa': 0,
                'nlvr': 0,
                'refcoco': 0,
                'vcr': 0,
                'caption': 0,
                'mmt': 0,
            }

            # vqa
            quesid2ans = {}
            train_acc = 0.
            # train_acc_steps = int(len(self.train_loader) * 0.05)
            # last_acc_step = 0

            # refcoco
            n_correct = 0
            n_total = 0

            for step_i, batch in enumerate(self.train_loader):

                # print(f'GPU{self.args.gpu} inside training loop')
                # print(batch)
                task = batch['task']
                # if self.verbose:
                #     print('task', task)
                task_counter[task] += 1

                batch['log_train_accuracy'] = self.args.log_train_accuracy

                # self.optim.zero_grad()
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # print(f'GPU{self.args.gpu} after loss')

                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # print(f'GPU{self.args.gpu} after backward')

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                # self.train_step_post_hook(result)

                if self.args.log_train_accuracy and task == 'refcoco':
                    correct = results['correct']
                    n_correct += sum(correct)
                    n_total += len(correct)

                if self.verbose:
                    if task == 'vqa':
                        vqa_loss_meter.update(loss.item())
                    elif task == 'refcoco':
                        refcoco_loss_meter.update(loss.item())

                    desc_str = f'Epoch {epoch} | LR {lr:.6f}'

                    desc_str += f" |"
                    if 'vqa' in self.args.tasks:
                        desc_str += f" VQA {task_counter['vqa']}"
                    if 'gqa' in self.args.tasks:
                        desc_str += f" GQA {task_counter['gqa']}"
                    if 'nlvr' in self.args.tasks:
                        desc_str += f" NLVR {task_counter['nlvr']}"
                    if 'vcr' in self.args.tasks:
                        desc_str += f" VCR {task_counter['vcr']}"
                    if 'refcoco' in self.args.tasks:
                        desc_str += f" RefCOCOg {task_counter['refcoco']}"
                    if 'caption' in self.args.tasks:
                        desc_str += f" COCO {task_counter['caption']}"
                    if 'mmt' in self.args.tasks:
                        desc_str += f" MMT {task_counter['mmt']}"

                    if len(vqa_loss_meter) > 0:
                        desc_str += f' | VQA Loss {vqa_loss_meter.val:4f}'
                    if len(refcoco_loss_meter) > 0:
                        desc_str += f' | RefCOCOg Loss {refcoco_loss_meter.val:.3f}'

                    if self.args.log_train_accuracy and n_total > 0:
                        desc_str += f' | RefCOCOg Acc'
                        desc_str += f' Correct {n_correct:.0f}'
                        desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()
                self.save("Epoch%02d" % (epoch + 1))

            if self.args.log_train_accuracy:
                train_score_dict = {'n_correct': n_correct, 'n_total': n_total}
                train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

            if self.verbose:
                # Validation
                log_str = ''
                wandb_log_dict = {}

                if 'vqa' in self.args.tasks:
                    # VQA
                    vqa_val_loader = self.val_loader['vqa']
                    score_dict = self.vqa_evaluate(vqa_val_loader)
                    valid_score = score_dict['topk_score'] * 100.
                    valid_score_raw = score_dict['overall']
                    if valid_score_raw > best_vqa_valid or epoch == 0:
                        best_vqa_valid = valid_score_raw
                        best_vqa_epoch = epoch
                        # self.save("VQA_BEST")
                    log_str += f"VQA"
                    log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (
                        epoch, valid_score_raw, valid_score)
                    log_str += "\nEpoch %d: Best Raw %0.2f\n" % (
                        best_vqa_epoch, best_vqa_valid)
                    wandb_log_dict['VQA/Valid/score'] = valid_score
                    wandb_log_dict['VQA/Valid/raw_score'] = score_dict[
                        'overall']
                if 'gqa' in self.args.tasks:
                    # GQA
                    gqa_val_loader = self.val_loader['gqa']
                    valid_score = self.gqa_evaluate(gqa_val_loader) * 100
                    if valid_score > best_gqa_valid or epoch == 0:
                        best_gqa_valid = valid_score
                        best_gqa_epoch = epoch
                    wandb_log_dict['GQA/Valid/Acc'] = valid_score
                    log_str += f"GQA"
                    log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_gqa_epoch,
                                                             best_gqa_valid)
                if 'nlvr' in self.args.tasks:
                    # NLVR
                    nlvr_val_loader = self.val_loader['nlvr']
                    valid_score_dict = self.nlvr_evaluate(nlvr_val_loader)
                    valid_acc = valid_score_dict['accuracy'] * 100.
                    if valid_acc > best_nlvr_valid or epoch == 0:
                        best_nlvr_valid = valid_acc
                        best_nlvr_epoch = epoch
                    wandb_log_dict['NLVR/Valid/Acc'] = valid_acc
                    log_str += f"NLVR"
                    log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_acc)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_nlvr_epoch,
                                                             best_nlvr_valid)

                if 'vcr' in self.args.tasks:
                    # VCR
                    vcr_val_loader = self.val_loader['vcr']
                    valid_score_dict = self.vcr_evaluate(vcr_val_loader)
                    valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[
                        'n_total'] * 100
                    valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[
                        'n_total'] * 100
                    valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[
                        'n_total'] * 100
                    valid_n_total = int(valid_score_dict['n_total'])
                    if valid_Q_AR > best_valid_Q_AR or epoch == 0:
                        best_valid_Q_AR = valid_Q_AR
                        best_vcr_epoch = epoch
                    wandb_log_dict['VCR/Valid/Q_A'] = valid_Q_A
                    wandb_log_dict['VCR/Valid/QA_R'] = valid_QA_R
                    wandb_log_dict['VCR/Valid/Q_AR'] = valid_Q_AR
                    log_str += f"VCR"
                    log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_Q_AR)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_vcr_epoch,
                                                             best_valid_Q_AR)

                if 'refcoco' in self.args.tasks:
                    # RefCOCO
                    refcoco_val_loader = self.val_loader['refcoco']
                    if self.args.log_train_accuracy:
                        train_acc = train_score_dict[
                            'n_correct'] / train_score_dict['n_total'] * 100
                        train_n_correct = int(train_score_dict['n_correct'])
                        train_n_total = int(train_score_dict['n_total'])
                    valid_score_dict = self.refcoco_evaluate(
                        refcoco_val_loader)
                    valid_acc = valid_score_dict[
                        'n_correct'] / valid_score_dict['n_total'] * 100
                    valid_n_correct = int(valid_score_dict['n_correct'])
                    valid_n_total = int(valid_score_dict['n_total'])
                    if valid_acc > best_refcoco_valid or epoch == 0:
                        best_refcoco_valid = valid_acc
                        best_refcoco_epoch = epoch
                    if self.args.log_train_accuracy:
                        wandb_log_dict['RefCOCO/Train/Acc'] = train_acc
                    wandb_log_dict['RefCOCO/Valid/Acc'] = valid_acc
                    log_str += f"RefCOCOg"
                    if self.args.log_train_accuracy:
                        log_str += f"\nEpoch {epoch}: Train"
                        log_str += f" Acc {train_acc:.2f}% |"
                        log_str += f" # correct {train_n_correct} # total {train_n_total}"
                    log_str += f"\nEpoch {epoch}: Valid"
                    log_str += f" Acc {valid_acc:.2f}% |"
                    log_str += f" # correct {valid_n_correct} # total {valid_n_total}"
                    log_str += f"\nEpoch {best_refcoco_epoch}: Best Acc {best_refcoco_valid:.2f}%\n"

                if 'caption' in self.args.tasks:
                    # COCO Caption
                    caption_val_loader = self.val_loader['caption']
                    valid_results = self.caption_evaluate(caption_val_loader)
                    valid_score = valid_results['CIDEr'] * 100
                    if valid_score > best_caption_valid or epoch == 0:
                        best_caption_valid = valid_score
                        best_caption_epoch = epoch
                    for score_name, score in valid_results.items():
                        wandb_log_dict[
                            f'Caption/Valid/{score_name}'] = score * 100
                    log_str += f"COCO Caption"
                    log_str += "\nEpoch %d: Valid CIDEr %0.2f" % (epoch,
                                                                  valid_score)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (
                        best_caption_epoch, best_caption_valid)

                if 'mmt' in self.args.tasks:
                    # MMT
                    mmt_val_loader = self.val_loader['mmt']
                    valid_results = self.mmt_evaluate(mmt_val_loader)
                    valid_score = valid_results['BLEU']
                    if valid_score > best_mmt_valid:
                        best_mmt_valid = valid_score
                        best_mmt_epoch = epoch
                    for score_name, score in valid_results.items():
                        wandb_log_dict[f'MMT/Valid/{score_name}'] = score
                    log_str += f"Multi30K En-De"
                    log_str += "\nEpoch %d: Valid BLEU %0.2f" % (epoch,
                                                                 valid_score)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_mmt_epoch,
                                                             best_mmt_valid)

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        # Test Set
        if self.verbose:
            self.save("LAST")

            log_str = ''
            wandb_log_dict = {}

            if 'vqa' in self.args.tasks:
                # VQA
                vqa_test_loader = self.test_loader['vqa']
                evaluator = vqa_test_loader.evaluator
                dump_path = os.path.join(self.args.output,
                                         'karpathy_test_predict.json')
                quesid2ans = self.vqa_predict(vqa_test_loader, dump_path)
                wandb.save(dump_path, base_path=self.args.output)

                acc_dict_all = evaluator.evaluate_raw(quesid2ans)
                acc_dict_answerable = evaluator.evaluate_raw(
                    quesid2ans, is_topk_optimal=True)
                acc_dict_unanswerable = evaluator.evaluate_raw(
                    quesid2ans, is_topk_optimal=False)

                wandb_log_dict['VQA/Test/overall'] = acc_dict_all['overall']
                wandb_log_dict['VQA/Test/topk_optimal'] = acc_dict_answerable[
                    'overall']
                wandb_log_dict[
                    'VQA/Test/topk_not_optimal'] = acc_dict_unanswerable[
                        'overall']

                vqa_submit_test_loader = self.test_loader['vqa_submit']
                dump_path = os.path.join(self.args.output, 'vqa_submit.json')
                self.vqa_predict(vqa_submit_test_loader, dump_path=dump_path)
                wandb.save(dump_path, base_path=self.args.output)

            if 'nlvr' in self.args.tasks:
                # NLVR
                nlvr_test_loader = self.test_loader['nlvr']
                dump_path = os.path.join(self.args.output, 'nlvr_submit.csv')
                test_score_dict = self.nlvr_evaluate(nlvr_test_loader,
                                                     dump_path=dump_path)
                wandb.save(dump_path, base_path=self.args.output)
                for score_name, score in test_score_dict.items():
                    wandb_log_dict[f'NLVR/Test/{score_name}'] = score * 100.
            if 'refcoco' in self.args.tasks:
                # RefCOCO
                refcoco_test_loader = self.test_loader['refcoco']
                test_score_dict = self.refcoco_evaluate(refcoco_test_loader)
                test_acc = test_score_dict['n_correct'] / test_score_dict[
                    'n_total'] * 100
                test_n_correct = int(test_score_dict['n_correct'])
                test_n_total = int(test_score_dict['n_total'])
                wandb_log_dict['RefCOCO/test/Acc'] = test_acc
                log_str = 'RefCOCOg'
                log_str += f"\nTest Acc {test_acc:.2f}%"
                log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}"
            if 'caption' in self.args.tasks:
                # COCO Caption
                caption_test_loader = self.test_loader['caption']
                test_results = self.caption_evaluate(caption_test_loader)
                for score_name, score in test_results.items():
                    wandb_log_dict[f'Caption/Test/{score_name}'] = score

            if 'mmt' in self.args.tasks:
                # MMT
                mmt_test2016_loader = self.test_loader['mmt_test2016']
                mmt_test2017_loader = self.test_loader['mmt_test2017']
                mmt_test2018_loader = self.test_loader['mmt_test2018']
                for loader in [
                        mmt_test2016_loader, mmt_test2017_loader,
                        mmt_test2018_loader
                ]:
                    split = loader.dataset.source
                    dump_path = os.path.join(self.args.output,
                                             f'submit_{split}_raw.txt')
                    test_results = self.mmt_evaluate(loader,
                                                     dump_path=dump_path)
                    for score_name, score in test_results.items():
                        wandb_log_dict[f'MMT/{split}/{score_name}'] = score
                    log_str += f'{split} set results\n'
                    log_str += pformat(test_results)

            print(log_str)
            wandb.log(wandb_log_dict, step=self.args.epochs)

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            exit()