Ejemplo n.º 1
0
    def evaluate_epoch(self, epoch):
        """Run one full validation pass over ``self.val_loader``.

        Accumulates per-loss running sums (``"<name>"``) and counts
        (``"<name>_count"``) into ``epoch_results`` and, when the 'qa' task
        is active, collects uid -> predicted-answer pairs.

        Args:
            epoch: epoch index, used only for the progress-bar description.

        Returns:
            Tuple ``(epoch_results, uid2ans)``; ``uid2ans`` is ``None`` when
            'qa' is not among the active losses.
        """
        LOSSES_NAME = self.args.LOSSES_NAME

        epoch_results = {}
        for loss_name in LOSSES_NAME:
            epoch_results[loss_name] = 0.
            epoch_results[f'{loss_name}_count'] = 0

        uid2ans = {}

        self.model.eval()
        with torch.no_grad():
            if self.verbose:
                # One meter per loss for the progress-bar display.
                loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]

                pbar = tqdm(total=len(self.val_loader), ncols=250)

            for step_i, batch in enumerate(self.val_loader):

                # DistributedDataParallel wraps the real model in `.module`.
                if self.args.distributed:
                    results = self.model.module.valid_step(batch)
                else:
                    results = self.model.valid_step(batch)

                if 'qa' in self.args.losses:
                    qa_pred = results['qa_pred']
                    for uid, ans in zip(batch['uid'], qa_pred):
                        uid2ans[uid] = ans

                # Fold scalar results into the epoch accumulators.
                for k, v in results.items():
                    if k in epoch_results:
                        if isinstance(v, int):
                            epoch_results[k] += v
                        elif isinstance(v, torch.Tensor):
                            epoch_results[k] += v.item()

                if self.verbose:
                    desc_str = f'Valid Epoch {epoch} |'
                    for loss_name, loss_meter in zip(LOSSES_NAME, loss_meters):

                        if loss_name in results:
                            loss_meter.update(results[f'{loss_name}'] / results[f'{loss_name}_count'])
                        if len(loss_meter) > 0:
                            loss_count = epoch_results[f'{loss_name}_count']
                            desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                # Fix: only synchronize when actually running distributed.
                # dist.barrier() raises if the process group is uninitialized,
                # and this method already supports non-distributed runs above.
                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()
            if self.args.distributed:
                dist.barrier()

            if 'qa' not in self.args.losses:
                uid2ans = None

            return epoch_results, uid2ans
Ejemplo n.º 2
0
    def __init__(self, args):
        """Set up fine-tuning from a pretrained encoder/decoder checkpoint.

        The encoder is fully frozen; only the last ``args.decoder_update``
        decoder layers remain trainable.

        Args:
            args: parsed CLI namespace; must provide ``data``, ``seq_len``,
                ``expName``, ``checkpoint`` (a ``Path``), ``decoder`` (ids),
                ``decoder_update``, ``lr`` and ``lr_decay``.
        """
        self.args = args
        self.args.n_datasets = len(args.data)
        self.modelPath = Path('checkpoints') / args.expName

        self.logger = create_output_dir(args, self.modelPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        # Per-dataset reconstruction meters plus an overall meter,
        # for both training and evaluation.
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.start_epoch = 0

        # Locate the pretrained checkpoint file ("<name>_<id>.pth") whose id
        # matches one of the requested decoder ids.
        checkpoint = args.checkpoint.parent.glob(args.checkpoint.name +
                                                 '_*.pth')
        checkpoint = [c for c in checkpoint
                      if extract_id(c) in args.decoder][0]

        model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]

        # Fix: build each network exactly once. The original constructed
        # Encoder/WaveNet twice and immediately discarded the first instances.
        self.encoder = Encoder(model_args)
        self.encoder.load_state_dict(torch.load(checkpoint)['encoder_state'])

        # Freeze the encoder entirely.
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.decoder = WaveNet(model_args)
        self.decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])

        # Freeze every decoder layer except the last `decoder_update` ones.
        for param in self.decoder.layers[:-args.decoder_update].parameters():
            param.requires_grad = False

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.step()
Ejemplo n.º 3
0
    def __init__(self, args):
        """Build the single-decoder fine-tuning trainer.

        Always loads the encoder weights from ``args.checkpoint``; when
        ``args.continue_training`` is set, also restores the decoder,
        optimizer state and the epoch counter. Only the decoder is optimized.
        """
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        # Per-dataset reconstruction meters plus totals, for train and eval.
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            # args.pth stores [args, last_epoch]; resume from the next epoch.
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        # Only decoder parameters are optimized; the encoder stays fixed.
        self.model_optimizer = optim.Adam(self.decoder.parameters(),
                                          lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])

        # Align the LR scheduler with the resumed epoch before stepping once.
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()
Ejemplo n.º 4
0
class Trainer:
    """Fine-tunes a WaveNet decoder on one or more datasets.

    The encoder is loaded from a pretrained checkpoint and kept in eval mode
    during training; only the decoder receives gradient updates. Tracks
    per-dataset reconstruction losses and checkpoints the best/last models.
    """

    def __init__(self, args):
        """Load checkpoint state, build models, meters, optimizer, scheduler."""
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        # Per-dataset reconstruction meters plus totals, for train and eval.
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)

        assert args.checkpoint, 'you MUST pass a checkpoint for the encoder'

        if args.continue_training:
            # args.pth stores [args, last_epoch]; resume from the next epoch.
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
        else:
            self.start_epoch = 0

        states = torch.load(args.checkpoint)
        self.encoder.load_state_dict(states['encoder_state'])
        if args.continue_training:
            self.decoder.load_state_dict(states['decoder_state'])
        self.logger.info('Loaded checkpoint parameters')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        # Only decoder parameters are optimized; the encoder stays fixed.
        self.model_optimizer = optim.Adam(self.decoder.parameters(),
                                          lr=args.lr)

        if args.continue_training:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])

        # Align the LR scheduler with the resumed epoch before stepping once.
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        """Score one validation batch and return the scalar total loss.

        ``x_aug`` is accepted for signature parity with train_batch but is
        not used during evaluation.
        """
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """Optimize the decoder on one batch and return the scalar loss.

        The latent code is computed from the augmented input and detached,
        so no gradients flow into the (frozen) encoder.
        """
        x, x_aug = x.float(), x_aug.float()

        # optimize G - reconstructs well
        z = self.encoder(x_aug)
        z = z.detach()  # stop gradients
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = recon_loss.mean()
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Train the decoder for ``args.epoch_len`` batches, cycling datasets."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        # Encoder stays in eval mode: it is frozen and only provides codes.
        self.encoder.eval()
        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                # --short: abort early for smoke tests.
                if self.args.short and batch_num == 3:
                    break

                # Round-robin over datasets.
                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def evaluate_epoch(self, epoch):
        """Evaluate on ~1/10 of an epoch's worth of validation batches."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, \
                torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Return each meter's epoch summary as a comma-separated 4-dp string."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        """Formatted training reconstruction losses (one per dataset)."""
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        """Formatted evaluation reconstruction losses (one per dataset)."""
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def train(self):
        """Main loop: train and evaluate each epoch, checkpoint best/last."""
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            # NOTE(review): f-string and lazy %-args are mixed here; the %s
            # placeholders survive the f-string, so this logs correctly.
            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            # Checkpoint whenever validation improves.
            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            # Persist [args, epoch] so continue_training can resume.
            torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save decoder/optimizer state plus the original encoder weights.

        The encoder weights are re-read from the original checkpoint since
        the encoder is frozen and never changes during training.
        """
        save_path = self.expPath / filename

        states = torch.load(self.args.checkpoint)

        torch.save(
            {
                'encoder_state': states['encoder_state'],
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')
Ejemplo n.º 5
0
Archivo: vqa.py Proyecto: j-min/VL-T5
    def train(self):
        """Full VQA fine-tuning loop.

        Trains for ``args.epochs`` epochs with optional fp16 (native AMP or
        apex), validates after every epoch (checkpointing the best model by
        raw validation score), then reloads the best checkpoint and reports
        test-set metrics. Only the verbose (rank-0) process logs to wandb
        and writes checkpoints.
        """
        if self.verbose:
            loss_meter = LossMeter()
            best_valid = 0.
            best_epoch = 0

            # Pick a wandb project name from the backbone/vision combination.
            # NOTE(review): `project_name` is unbound if the backbone is
            # neither t5 nor bart — wandb.init would raise NameError; confirm
            # only these two families are supported.
            if 't5' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLT5_VQA"
                else:
                    project_name = "T5_VQA"
            elif 'bart' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLBart_VQA"
                else:
                    project_name = "Bart_VQA"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot the source tree alongside the run.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # Shift epoch numbering when resuming from a checkpoint.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=120)

            epoch_results = {
                'loss': 0.,

            }

            # NOTE(review): quesid2ans is reset here but never filled during
            # training; predictions are gathered by predict() below.
            quesid2ans = {}

            for step_i, batch in enumerate(self.train_loader):

                # Forward pass, under autocast when native AMP is active.
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Backward pass: native AMP scaler, apex, or plain.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        # Unscale before clipping so the norm is in true units.
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(
                            self.optim), self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Zero gradients by dropping them (cheaper than zero_grad()).
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                # Accumulate any scalar results we track for the epoch.
                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # get_last_lr exists from torch 1.4; fall back for older APIs.
                if self.lr_scheduler:
                    if version.parse(torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f}'
                    desc_str += f' | Loss {loss_meter.val:4f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()

            # Validation
            score_dict = self.evaluate(self.val_loader)

            if self.verbose:
                valid_score = score_dict['topk_score'] * 100.
                valid_score_raw = score_dict['overall']
                # Keep the checkpoint with the best raw validation score.
                if valid_score_raw > best_valid or epoch == 0:
                    best_valid = valid_score_raw
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''
                log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (epoch, valid_score_raw, valid_score)
                log_str += "\nEpoch %d: Best Raw %0.2f\n" % (best_epoch, best_valid)

                wandb_log_dict = {}
                wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(self.train_loader)

                wandb_log_dict['Valid/score'] = valid_score

                wandb_log_dict['Valid/raw_score'] = score_dict['overall']
                for qtype, score in score_dict['perQuestionType'].items():
                    wandb_log_dict[f'Valid_Qtypes/{qtype}'] = score
                for atype, score in score_dict['perAnswerType'].items():
                    # wandb treats '/' as a section separator; rename the key.
                    if atype == 'yes/no':
                        atype = 'yes_no'
                    wandb_log_dict[f'Valid_Atypes/{atype}'] = score

                wandb.log(wandb_log_dict, step=epoch)
                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

        # Test Set
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)

        quesid2ans = self.predict(self.test_loader)

        if self.verbose:
            evaluator = self.test_loader.evaluator
            score_dict = evaluator.evaluate(quesid2ans)

            evaluator.dump_result(quesid2ans)

            # Break accuracy out by whether the top-k answer set was optimal.
            acc_dict_all = evaluator.evaluate_raw(quesid2ans)
            acc_dict_answerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=True)
            acc_dict_unanswerable = evaluator.evaluate_raw(quesid2ans, is_topk_optimal=False)

            wandb_log_dict = {}
            wandb_log_dict['Test/overall'] = acc_dict_all['overall']
            wandb_log_dict['Test/topk_optimal'] = acc_dict_answerable['overall']
            wandb_log_dict['Test/topk_not_optimal'] = acc_dict_unanswerable['overall']

            for qtype, score in acc_dict_all['perQuestionType'].items():
                wandb_log_dict[f'Test_Qtypes/{qtype}'] = score
            for atype, score in acc_dict_all['perAnswerType'].items():
                if atype == 'yes/no':
                    atype = 'yes_no'
                wandb_log_dict[f'Test_Atypes/{atype}'] = score

            print(wandb_log_dict)
            wandb.log(wandb_log_dict)

        # Optionally produce a leaderboard submission file.
        if self.args.submit:
            dump_path = os.path.join(self.args.output, 'submit.json')
            self.predict(self.submit_test_loader, dump_path)

            wandb.save(dump_path, base_path=self.args.output)
            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            exit()
Ejemplo n.º 6
0
    def __init__(self, args):
        """Build the adversarial (encoder/decoder/discriminator) trainer.

        Supports single-node DataParallel or multi-node DistributedDataParallel
        (one dataset per worker), and resuming from ``args.checkpoint``.
        """
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        # In distributed mode each node handles exactly one dataset.
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']
        ), "Number of datasets must match number of nodes"

        # Per-dataset reconstruction meters, discriminator meter, and totals,
        # for both training and evaluation.
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.checkpoint:
            # args.pth stores [args, last_epoch]; resume from the next epoch.
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            states = torch.load(args.checkpoint)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(
                self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator)
            self.logger.info('Created DistributedDataParallel')
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(
                self.discriminator).cuda()
        # The decoder is always wrapped with plain DataParallel.
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        # Separate optimizers: one for the generator (encoder+decoder),
        # one for the discriminator.
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        if args.checkpoint and args.load_optimizer:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])
            self.d_optimizer.load_state_dict(states['d_optimizer_state'])

        # Align the LR scheduler with the resumed epoch before stepping once.
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()
Ejemplo n.º 7
0
class Trainer:
    """Adversarial trainer: encoder/decoder generator vs. a latent-space
    discriminator that tries to identify the source dataset of each code."""

    def __init__(self, args):
        """Build models, meters, optimizers and scheduler; optionally resume
        from ``args.checkpoint``; wrap models for (distributed) data-parallel."""
        self.args = args
        self.args.n_datasets = len(self.args.data)
        self.expPath = Path('checkpoints') / args.expName

        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        self.logger = create_output_dir(args, self.expPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]
        # In distributed mode each node handles exactly one dataset.
        assert not args.distributed or len(self.data) == int(
            os.environ['WORLD_SIZE']
        ), "Number of datasets must match number of nodes"

        # Per-dataset reconstruction meters, discriminator meter, and totals,
        # for both training and evaluation.
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_d_right = LossMeter('d')
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_d_right = LossMeter('eval d')
        self.eval_total = LossMeter('eval total')

        self.encoder = Encoder(args)
        self.decoder = WaveNet(args)
        self.discriminator = ZDiscriminator(args)

        if args.checkpoint:
            # args.pth stores [args, last_epoch]; resume from the next epoch.
            checkpoint_args_path = os.path.dirname(
                args.checkpoint) + '/args.pth'
            checkpoint_args = torch.load(checkpoint_args_path)

            self.start_epoch = checkpoint_args[-1] + 1
            states = torch.load(args.checkpoint)

            self.encoder.load_state_dict(states['encoder_state'])
            self.decoder.load_state_dict(states['decoder_state'])
            self.discriminator.load_state_dict(states['discriminator_state'])

            self.logger.info('Loaded checkpoint parameters')
        else:
            self.start_epoch = 0

        if args.distributed:
            self.encoder.cuda()
            self.encoder = torch.nn.parallel.DistributedDataParallel(
                self.encoder)
            self.discriminator.cuda()
            self.discriminator = torch.nn.parallel.DistributedDataParallel(
                self.discriminator)
            self.logger.info('Created DistributedDataParallel')
        else:
            self.encoder = torch.nn.DataParallel(self.encoder).cuda()
            self.discriminator = torch.nn.DataParallel(
                self.discriminator).cuda()
        # The decoder is always wrapped with plain DataParallel.
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()

        # Separate optimizers for the generator (encoder+decoder) and the
        # discriminator.
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)
        self.d_optimizer = optim.Adam(self.discriminator.parameters(),
                                      lr=args.lr)

        if args.checkpoint and args.load_optimizer:
            self.model_optimizer.load_state_dict(
                states['model_optimizer_state'])
            self.d_optimizer.load_state_dict(states['d_optimizer_state'])

        # Align the LR scheduler with the resumed epoch before stepping once.
        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.last_epoch = self.start_epoch
        self.lr_manager.step()

    def eval_batch(self, x, x_aug, dset_num):
        """Score one validation batch; return the combined scalar loss.

        ``x_aug`` is accepted for signature parity with train_batch but is
        not used during evaluation.
        """
        x, x_aug = x.float(), x_aug.float()

        latent = self.encoder(x)
        decoded = self.decoder(x, latent)
        logits = self.discriminator(latent)

        # Fraction of codes the discriminator assigns to the true dataset.
        predicted = torch.max(logits, dim=1)[1]
        accuracy = (predicted == dset_num).float().mean()
        self.eval_d_right.add(accuracy.data.item())

        targets = torch.tensor([dset_num] * x.size(0)).long().cuda()
        discriminator_right = F.cross_entropy(logits, targets).mean()

        recon_loss = cross_entropy_loss(decoded, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        # Combined objective: weighted discriminator term plus reconstruction.
        total_loss = (discriminator_right.data.item() * self.args.d_lambda
                      + recon_loss.mean().data.item())
        self.eval_total.add(total_loss)

        return total_loss

    def train_batch(self, x, x_aug, dset_num):
        """One adversarial training step; returns the generator's scalar loss.

        First updates the discriminator to classify the latent code's source
        dataset, then updates the generator (encoder+decoder) to reconstruct
        well while fooling the discriminator.
        """
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - discriminator right
        z = self.encoder(x)
        z_logits = self.discriminator(z)
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        loss = discriminator_right * self.args.d_lambda
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(),
                             self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        # (fresh forward pass on the augmented input; gradients flow into the
        # encoder here, with the discriminator loss negated).
        z = self.encoder(x_aug)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)
        discriminator_wrong = -F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        # Log diagnostics when the discriminator loss blows up.
        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong)

        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Train all three models for ``args.epoch_len`` batches."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_d_right.reset()
        self.loss_total.reset()

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        total = self.args.epoch_len

        with tqdm(total=total, desc='Train epoch %d' % epoch) as progress:
            for batch_num in range(total):
                # --short: abort early for smoke tests.
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    # Each distributed worker trains on exactly one dataset.
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    # Single node: round-robin over datasets.
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)
                batch_loss = self.train_batch(wrap(x), wrap(x_aug), dset_num)

                progress.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                progress.update()

    def evaluate_epoch(self, epoch):
        """Evaluate on ~1/10 of an epoch's worth of validation batches."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_d_right.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        total = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=total) as progress, torch.no_grad():
            for batch_num in range(total):
                # --short: abort early for smoke tests.
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    # Each distributed worker evaluates its own dataset.
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)
                batch_loss = self.eval_batch(wrap(x), wrap(x_aug), dset_num)

                progress.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                progress.update()

    @staticmethod
    def format_losses(meters):
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        """Formatted training losses: reconstruction meters followed by the
        discriminator meter."""
        meters = list(self.losses_recon) + [self.loss_d_right]
        return self.format_losses(meters)

    def eval_losses(self):
        """Formatted validation losses: reconstruction meters followed by the
        discriminator meter."""
        meters = list(self.evals_recon) + [self.eval_d_right]
        return self.format_losses(meters)

    def train(self):
        """Main loop: train and evaluate each epoch, step the LR manager,
        and checkpoint the best/last models.

        Saves `bestmodel_<rank>.pth` whenever validation loss improves, plus
        either a rolling `lastmodel_<rank>.pth` or one checkpoint per epoch
        depending on `args.per_epoch`.
        """
        best_eval = float('inf')

        # Begin!
        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            # Use consistent lazy %-formatting for logging (the original
            # mixed f-string interpolation with %-style lazy args).
            self.logger.info(
                'Starting epoch, Rank %s, Dataset: %s',
                self.args.rank, self.args.data[self.args.rank])
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                'Epoch %s Rank %s - Train loss: (%s), Test loss (%s)',
                epoch, self.args.rank, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            # Checkpoint whenever validation loss improves.
            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            # Keep a rolling "last" checkpoint, or one per epoch if requested.
            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            # Only the master process records run arguments.
            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.expPath)

            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Serialize model and optimizer state to `filename` under the
        experiment directory.

        Note: reads `.module` on each network, i.e. the networks are expected
        to be wrapped (e.g. DataParallel/DDP).
        """
        save_path = self.expPath / filename

        checkpoint = {
            'encoder_state': self.encoder.module.state_dict(),
            'decoder_state': self.decoder.module.state_dict(),
            'discriminator_state': self.discriminator.module.state_dict(),
            'model_optimizer_state': self.model_optimizer.state_dict(),
            'dataset': self.args.rank,
            'd_optimizer_state': self.d_optimizer.state_dict(),
        }
        torch.save(checkpoint, save_path)

        self.logger.debug(f'Saved model to {save_path}')
Ejemplo n.º 8
0
    def train(self):
        """Train loop for a binary "exists" task.

        First prints an oracle-accuracy baseline over val/test, then trains
        with optional fp16 and gradient clipping, validates each epoch
        (checkpointing the best model), logs to wandb, and finally evaluates
        the best checkpoint on the test set.
        """
        if self.verbose:
            # Oracle baseline: fraction of positive targets, i.e. the accuracy
            # of always predicting "exists".
            n_correct = 0
            n_total = 0
            for batch in self.val_loader:
                exists_target = batch['exists_target']

                n_correct += exists_target.sum().item()
                n_total += len(exists_target)

            print(f'Val Oracle acc: {n_correct / n_total * 100:.2f}%')

            n_correct = 0
            n_total = 0
            for batch in self.test_loader:
                exists_target = batch['exists_target']

                n_correct += exists_target.sum().item()
                n_total += len(exists_target)

            print(f'Test Oracle acc: {n_correct / n_total * 100:.2f}%')

        if self.verbose:
            loss_meter = LossMeter()

            best_valid_acc = 0.
            best_epoch = 0

            # wandb project name depends on the backbone family.
            if 't5' in self.args.backbone:
                project_name = "VLT5_RefCOCOg"
            elif 'bart' in self.args.backbone:
                project_name = "VLBart_RefCOCOg"

            if self.args.RefCOCO_GT:
                project_name += '_GT'

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot the source files alongside the run.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # Resume support: shift the epoch number when restarting.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=120)

            epoch_results = {
                'loss': 0,
            }

            n_correct = 0
            n_total = 0

            for step_i, batch in enumerate(self.train_loader):

                batch['log_train_accuracy'] = self.args.log_train_accuracy

                # Forward pass; under DDP the underlying module is reached
                # through `.module`. Optionally under native AMP autocast.
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Backward pass: native AMP scaler, apex, or plain backward.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        # Unscale before clipping so the norm is in fp32 units.
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Drop gradients instead of zeroing them (cheaper).
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                # Accumulate tracked scalar results for the epoch.
                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Current LR: scheduler API changed at torch 1.4; plain
                # optimizers fall back to the configured base LR.
                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.args.log_train_accuracy:
                    correct = results['correct']
                    n_correct += sum(correct)
                    n_total += len(correct)

                if self.verbose:
                    loss_meter.update(loss.item())
                    # acc_meter.update(results['acc'].item())

                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    desc_str += f'Loss {loss_meter.val:.3f} |'
                    # desc_str += f' Acc {acc_meter.val:.3f} |'

                    if self.args.log_train_accuracy:
                        desc_str += f' Correct {n_correct:.0f}'
                        desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()

            if self.args.log_train_accuracy:
                # Aggregate train-accuracy counts across distributed workers.
                train_score_dict = {'n_correct': n_correct, 'n_total': n_total}
                train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

            # Validation
            # valid_score_dict = self.evaluate(self.val_loader)
            # valid_score_dict = reduce_dict(valid_score_dict, self.args.gpu)

            if self.verbose:
                if self.args.log_train_accuracy:
                    train_acc = train_score_dict[
                        'n_correct'] / train_score_dict['n_total'] * 100
                    train_n_correct = int(train_score_dict['n_correct'])
                    train_n_total = int(train_score_dict['n_total'])

                # Validation
                # NOTE(review): validation runs only on the verbose (master)
                # process; other ranks wait at the barrier below.
                valid_score_dict = self.evaluate(self.val_loader)
                valid_acc = valid_score_dict['n_correct'] / valid_score_dict[
                    'n_total'] * 100
                valid_n_correct = int(valid_score_dict['n_correct'])
                valid_n_total = int(valid_score_dict['n_total'])

                # Checkpoint whenever validation accuracy improves.
                if valid_acc > best_valid_acc:
                    best_valid_acc = valid_acc
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''

                if self.args.log_train_accuracy:
                    log_str += f"\nEpoch {epoch}: Train"
                    log_str += f" Acc {train_acc:.2f}% |"
                    log_str += f" # correct {train_n_correct} # total {train_n_total}"

                log_str += f"\nEpoch {epoch}: Valid"
                log_str += f" Acc {valid_acc:.2f}% |"
                log_str += f" # correct {valid_n_correct} # total {valid_n_total}"

                log_str += f"\nEpoch {best_epoch}: Best  Acc {best_valid_acc:.2f}%\n"

                wandb_log_dict = {}

                if self.args.log_train_accuracy:
                    wandb_log_dict['Train/Acc'] = train_acc

                wandb_log_dict['Valid/Acc'] = valid_acc

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

            # Test Set
            # Reload the best validation checkpoint before testing.
            best_path = os.path.join(self.args.output, 'BEST')
            self.load(best_path)

            test_score_dict = self.evaluate(self.test_loader)
            test_acc = test_score_dict['n_correct'] / test_score_dict[
                'n_total'] * 100
            test_n_correct = int(test_score_dict['n_correct'])
            test_n_total = int(test_score_dict['n_total'])

            wandb_log_dict = {}
            wandb_log_dict['Test/Acc'] = test_acc
            wandb.log(wandb_log_dict, step=epoch)

            log_str = ''
            log_str += f"\nTest Acc {test_acc:.2f}%"
            log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}"

            print(log_str)

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
Ejemplo n.º 9
0
    def train(self):
        """GQA train loop.

        Trains per epoch with optional fp16 and gradient clipping, evaluates
        on Testdev each epoch (checkpointing the best score), logs to wandb,
        and finally dumps a submission file from the best checkpoint.
        """
        if self.verbose:
            loss_meter = LossMeter()

            best_valid = 0.
            best_epoch = 0

            # wandb project name depends on backbone family and vision usage.
            if 't5' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLT5_GQA"
                else:
                    project_name = "T5_GQA"
            elif 'bart' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLBart_GQA"
                else:
                    project_name = "Bart_GQA"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot the source files alongside the run.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        # torch.autograd.set_detect_anomaly(True)

        # print(f'GPU{self.args.gpu} before training starts')

        global_step = 0
        for epoch in range(self.args.epochs):
            # Resume support: shift the epoch number when restarting.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=120)

            epoch_results = {
                'loss': 0.,
            }

            quesid2ans = {}
            train_acc = 0.
            train_acc_steps = int(len(self.train_loader) * 0.05)
            last_acc_step = 0

            # print(f'GPU{self.args.gpu} before training loop')

            for step_i, batch in enumerate(self.train_loader):

                # print(f'GPU{self.args.gpu} inside training loop')
                # print(batch)

                # self.optim.zero_grad()
                # Forward pass; under DDP the underlying module is reached
                # through `.module`. Optionally under native AMP autocast.
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # print(f'GPU{self.args.gpu} after loss')

                # Backward pass: native AMP scaler, apex, or plain backward.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # print(f'GPU{self.args.gpu} after backward')

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        # Unscale before clipping so the norm is in fp32 units.
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # self.model.zero_grad()
                # Drop gradients instead of zeroing them (cheaper).
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                # Accumulate tracked scalar results for the epoch.
                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Current LR: scheduler API changed at torch 1.4; plain
                # optimizers fall back to the configured base LR.
                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f}'
                    # NOTE(review): ':4f' is fixed-point with min width 4,
                    # default 6 decimals; ':.4f' was likely intended.
                    desc_str += f' | Loss {loss_meter.val:4f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()

                log_str = ''

                # Validation
                # NOTE(review): evaluation runs only on the verbose (master)
                # process; other ranks wait at the barrier below.
                valid_score = self.evaluate(self.val_loader) * 100.
                # Checkpoint whenever the Testdev score improves.
                if valid_score > best_valid:
                    best_valid = valid_score
                    best_epoch = epoch
                    self.save("BEST")

                log_str += "\nEpoch %d: Testdev %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_epoch,
                                                         best_valid)

                wandb_log_dict = {}
                wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(
                    self.train_loader)

                wandb_log_dict['Testdev/score'] = valid_score

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

            # Test Set
            # Reload the best checkpoint and dump submission predictions.
            best_path = os.path.join(self.args.output, 'BEST')
            self.load(best_path)

            dump_path = os.path.join(self.args.output, 'submit.json')
            self.predict(self.test_loader, dump_path=dump_path)

            wandb.save(dump_path, base_path=self.args.output)

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            exit()
Ejemplo n.º 10
0
Archivo: vcr.py Proyecto: j-min/VL-T5
    def train(self):
        """VCR train loop (Q->A, QA->R, Q->AR).

        Trains with optional fp16 and gradient accumulation, tracks separate
        QA/QAR losses, validates each epoch (checkpointing the best Q->AR
        accuracy), logs to wandb, and dumps test-set results from the best
        checkpoint.
        """
        if self.verbose:
            loss_meter = LossMeter()
            qa_loss_meter = LossMeter()
            qar_loss_meter = LossMeter()

            best_valid_Q_AR = 0.
            best_epoch = 0

            # wandb project name depends on backbone family and vision usage.
            if 't5' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLT5_VCR"
                else:
                    project_name = "T5_VCR"
            elif 'bart' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLBart_VCR"
                else:
                    project_name = "Bart_VCR"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot the source files alongside the run.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # Resume support: shift the epoch number when restarting.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=200)

            epoch_results = {
                'loss': 0,
            }

            # Per-epoch accuracy counters for the three VCR sub-tasks.
            Q_A_results = 0
            QA_R_results = 0
            Q_AR_results = 0
            n_total = 0

            # Micro-batch bookkeeping for gradient accumulation.
            n_accu = 0
            train_loss = 0
            train_qa_loss = 0
            train_qar_loss = 0

            for step_i, batch in enumerate(self.train_loader):

                batch['log_train_accuracy'] = self.args.log_train_accuracy

                # Forward pass; under DDP the underlying module is reached
                # through `.module`. Optionally under native AMP autocast.
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Scale the loss down so accumulated gradients average out.
                if self.args.gradient_accumulation_steps > 1:
                    loss /= self.args.gradient_accumulation_steps

                # Backward pass: native AMP scaler, apex, or plain backward.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        # Unscale before clipping so the norm is in fp32 units.
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                # Only step the optimizer every `gradient_accumulation_steps`
                # micro-batches (and on the final batch of the epoch).
                update = True
                if self.args.gradient_accumulation_steps > 1:
                    # if step_i == 0:
                    #     update = False
                    # elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(self.train_loader) - 1:
                    #     update = True
                    # else:
                    #     update = False
                    update = ((step_i + 1) %
                              self.args.gradient_accumulation_steps
                              == 0) or (step_i == len(self.train_loader) - 1)
                n_accu += 1

                if update:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optim)
                        self.scaler.update()
                    else:
                        self.optim.step()

                    if self.lr_scheduler:
                        self.lr_scheduler.step()
                    # self.model.zero_grad()
                    # Drop gradients instead of zeroing them (cheaper).
                    for param in self.model.parameters():
                        param.grad = None

                    global_step += 1

                # Accumulate tracked scalar results for the epoch.
                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Current LR: scheduler API changed at torch 1.4; plain
                # optimizers fall back to the configured base LR.
                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.verbose:
                    # Sum micro-batch losses; renormalized on each update
                    # below. QA/QAR components are pre-divided to match the
                    # scaled total loss.
                    train_loss += loss.detach().item()
                    train_qa_loss += results['qa_loss'].item(
                    ) / self.args.gradient_accumulation_steps
                    train_qar_loss += results['qar_loss'].item(
                    ) / self.args.gradient_accumulation_steps

                    if self.args.log_train_accuracy:
                        qa_pred = results['qa_pred']
                        qar_pred = results['qar_pred']

                        a_labels = batch['answer_labels']
                        r_labels = batch['rationale_labels']

                        # Q->AR is correct only when both stages are correct.
                        Q_A_correct = a_labels == qa_pred
                        QA_R_correct = r_labels == qar_pred
                        Q_AR_correct = Q_A_correct & QA_R_correct

                        Q_A_results += sum(Q_A_correct)
                        QA_R_results += sum(QA_R_correct)
                        Q_AR_results += sum(Q_AR_correct)
                        n_total += len(qa_pred)

                    if update:
                        # Average the accumulated losses over the micro-batches
                        # actually seen since the last optimizer step.
                        if self.args.gradient_accumulation_steps > 1:
                            train_loss *= self.args.gradient_accumulation_steps / n_accu
                            train_qa_loss *= self.args.gradient_accumulation_steps / n_accu
                            train_qar_loss *= self.args.gradient_accumulation_steps / n_accu

                        loss_meter.update(train_loss)
                        qa_loss_meter.update(train_qa_loss)
                        qar_loss_meter.update(train_qar_loss)
                        desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step} |'
                        desc_str += f' Loss {loss_meter.val:.3f} |'
                        desc_str += f' QA Loss {qa_loss_meter.val:.3f} |'
                        desc_str += f' QAR Loss {qar_loss_meter.val:.3f} |'

                        # Reset micro-batch accumulators after each update.
                        train_loss = 0
                        train_qa_loss = 0
                        train_qar_loss = 0
                        n_accu = 0

                        if self.args.log_train_accuracy:
                            desc_str += f' Q -> A {Q_A_results} ({Q_A_results/n_total*100:.1f}%)'
                            desc_str += f' QA -> R {QA_R_results} ({QA_R_results/n_total*100:.1f}%)'
                            desc_str += f' Q -> AR {Q_AR_results} ({Q_AR_results/n_total*100:.1f}%)'

                        pbar.set_description(desc_str)
                    pbar.update(1)

            if self.verbose:
                pbar.close()

            if self.args.log_train_accuracy:
                # Aggregate train-accuracy counts across distributed workers.
                train_score_dict = {
                    'Q_A': Q_A_results,
                    'QA_R': QA_R_results,
                    'Q_AR': Q_AR_results,
                    'n_total': n_total
                }
                train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

            # Validation
            valid_score_dict = self.evaluate_val(self.val_loader)

            if self.verbose:
                if self.args.log_train_accuracy:
                    train_Q_A = train_score_dict['Q_A'] / train_score_dict[
                        'n_total'] * 100
                    train_QA_R = train_score_dict['QA_R'] / train_score_dict[
                        'n_total'] * 100
                    train_Q_AR = train_score_dict['Q_AR'] / train_score_dict[
                        'n_total'] * 100
                    train_n_total = int(train_score_dict['n_total'])

                valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[
                    'n_total'] * 100
                valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[
                    'n_total'] * 100
                valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[
                    'n_total'] * 100
                valid_n_total = int(valid_score_dict['n_total'])

                # Checkpoint whenever joint Q->AR accuracy improves.
                if valid_Q_AR > best_valid_Q_AR:
                    best_valid_Q_AR = valid_Q_AR
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''

                if self.args.log_train_accuracy:
                    log_str += f"\nEpoch {epoch}: Train |"
                    log_str += f" # examples: {train_n_total} |"
                    log_str += f" Q -> A {train_Q_A:.2f}%"
                    log_str += f" QA -> R {train_QA_R:.2f}%"
                    log_str += f" Q -> AR {train_Q_AR:.2f}%"

                log_str += f"\nEpoch {epoch}: Valid |"
                log_str += f" # examples: {valid_n_total} |"
                log_str += f" Q -> A {valid_Q_A:.2f}%"
                log_str += f" QA -> R {valid_QA_R:.2f}%"
                log_str += f" Q -> AR {valid_Q_AR:.2f}%"

                #log_str += "\nEpoch %d: Valid Q -> AR %0.2f" % (epoch, valid_Q_AR)
                log_str += f"\nBest Epoch {best_epoch}: Q -> AR {best_valid_Q_AR:.2f}%\n"

                wandb_log_dict = {}
                # wandb_log_dict['Train/Loss'] = loss_meter.val

                if self.args.log_train_accuracy:
                    wandb_log_dict['Train/Q_A'] = train_Q_A
                    wandb_log_dict['Train/QA_R'] = train_QA_R
                    wandb_log_dict['Train/Q_AR'] = train_Q_AR

                wandb_log_dict['Valid/Q_A'] = valid_Q_A
                wandb_log_dict['Valid/QA_R'] = valid_QA_R
                wandb_log_dict['Valid/Q_AR'] = valid_Q_AR

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

        # Test Set
        # Reload the best checkpoint on all ranks before testing.
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)

        if self.verbose:
            dump_path = os.path.join(self.args.output, 'test_submit.csv')

            print('Dumping test set results at', dump_path)
            self.evaluate_test(self.test_loader, dump_path=dump_path)
            wandb.save(dump_path, base_path=self.args.output)

            print('Done!')

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            exit()
Ejemplo n.º 11
0
    def train(self):
        if self.verbose:
            loss_meter = LossMeter()
            best_valid = 0.

            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.args.log_dir)

            hparam_dict = {}
            for k, v in self.args.__dict__.items():
                if type(v) in [int, float, str, bool, torch.Tensor]:
                    hparam_dict[k] = v
            metric_dict = {}

            self.writer.add_hparams(hparam_dict, metric_dict)

        if self.args.distributed:
            dist.barrier()

        self.optim.zero_grad()

        for epoch in range(self.args.epochs):
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=150)

                results = np.zeros(4, dtype=int)
                quesid2ans = {}

            for step_i, batch in enumerate(self.train_loader):
                vis_feats = batch['vis_feats'].cuda()
                boxes = batch['boxes'].cuda()

                ques_id = batch['question_ids']
                B = len(ques_id)

                input_ids = batch['word_ids'].cuda()
                input_ids = input_ids.unsqueeze(1).repeat(1, 2,
                                                          1).view(B * 2, -1)
                label = batch['labels'].cuda()

                results = self.model(
                    input_ids=input_ids,
                    visual_feats=vis_feats,
                    visual_pos=boxes,
                    attention_mask=input_ids > 0,
                )

                logit = results['logit']

                loss = self.mce_loss(logit, label)

                loss.backward()

                update = True
                if self.args.update_freq > 1:
                    if step_i == 0:
                        update = False
                    elif step_i % self.args.update_freq == 0 or step_i == len(
                            self.train_loader) - 1:
                        update = True
                    else:
                        update = False

                if update:
                    if not self.args.no_clip_grad:
                        nn.utils.clip_grad_norm_(self.model.parameters(),
                                                 self.args.clip_grad_norm)

                    self.optim.step()
                    self.lr_scheduler.step()
                    for param in self.model.parameters():
                        param.grad = None

                try:
                    lr = self.scheduler.get_last_lr()[0]
                except AttributeError:
                    lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    desc_str += f'Loss {loss_meter.val:4f} |'

                    score, predict = logit.max(1)
                    predict = predict.cpu().numpy()
                    label = label.cpu().numpy()

                    for qid, pred in zip(ques_id, predict):
                        quesid2ans[qid] = pred

                    results[0] += sum((label == 1) & (predict == 1))
                    results[1] += sum((label == 1) & (predict == 0))
                    results[2] += sum((label == 0) & (predict == 1))
                    results[3] += sum((label == 0) & (predict == 0))
                    n_total = sum(results)

                    desc_str += f' TP {results[0]} ({results[0]/n_total*100:.1f}%)'
                    desc_str += f' FN {results[1]} ({results[1]/n_total*100:.1f}%)'
                    desc_str += f' FP {results[2]} ({results[2]/n_total*100:.1f}%)'
                    desc_str += f' TN {results[3]} ({results[3]/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()
                score = self.train_loader.evaluator.evaluate(quesid2ans) * 100.
                log_str = "\nEpoch %d: Train %0.2f" % (epoch, score)

                if not self.args.dry:
                    self.writer.add_scalar(f'NLVR/Train/score', score, epoch)

                # Validation
                valid_score = self.evaluate(self.val_loader) * 100.
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid)

                if not self.args.dry:
                    self.writer.add_scalar(f'NLVR/Valid/score', valid_score,
                                           epoch)

                print(log_str)
                self.logger.info(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")
Ejemplo n.º 12
0
Archivo: nlvr.py Proyecto: j-min/VL-T5
    def train(self):
        """Run the NLVR2 fine-tuning loop.

        For each epoch: one pass over ``self.train_loader`` (forward via
        ``train_step``, backward with optional fp16 through native amp or
        apex, gradient clipping, optimizer + LR-scheduler step), followed
        by validation on ``self.val_loader``.  The checkpoint with the best
        validation accuracy is saved as "BEST"; after the last epoch the
        model is saved as "LAST", the best checkpoint is re-loaded, and the
        test set is evaluated with predictions dumped to ``submit.csv``.
        Progress bars, wandb logging, checkpointing and evaluation run only
        on the verbose (rank-0) process; other ranks just train and sync at
        ``dist.barrier()`` points.
        """
        if self.verbose:
            loss_meter = LossMeter()
            quesid2ans = {}
            best_valid = 0.
            best_epoch = 0

            # Pick the wandb project from the backbone name; fail loudly on
            # anything else instead of hitting a NameError at wandb.init.
            if 't5' in self.args.backbone:
                project_name = "VLT5_NLVR"
            elif 'bart' in self.args.backbone:
                project_name = "VLBART_NLVR"
            else:
                raise ValueError(
                    f"Unsupported backbone for NLVR: {self.args.backbone!r}"
                    " (expected a 't5' or 'bart' variant)")

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot every *.py source next to this file into the wandb run.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir, "*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # Resuming: shift the epoch index so logs/checkpoints continue
            # from where the previous run stopped.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                # Re-seed the sampler so every epoch shuffles differently.
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=150)
                # Running confusion-matrix counters: [TP, FN, FP, TN].
                nlvr_results = np.zeros(4, dtype=int)

            epoch_results = {
                'loss': 0,
            }

            quesid2ans = {}
            train_acc = 0.

            for step_i, batch in enumerate(self.train_loader):

                self.optim.zero_grad()

                # Forward pass (under autocast when native amp is in use).
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Backward: scaled for native amp / apex, plain otherwise.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Gradient clipping (native amp requires unscaling first).
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(
                            self.optim), self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                # Optimizer step (through the scaler under native amp).
                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Drop gradients by rebinding to None: cheaper than the
                # memset done by zero_grad().
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Read the current learning rate for the progress bar;
                # get_last_lr() only exists from torch 1.4 onwards.
                if self.lr_scheduler:
                    if version.parse(torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    desc_str += f'Loss {loss_meter.val:4f} |'

                    pred_ans = results['pred_ans_id']
                    ques_ids = batch['question_ids']

                    for qid, ans in zip(ques_ids, pred_ans):
                        quesid2ans[qid] = ans

                    # Accumulate the train-side confusion matrix against the
                    # binary NLVR labels.
                    label = batch['labels'].cpu().numpy()
                    predict = pred_ans
                    nlvr_results[0] += sum((label == 1) & (predict == 1))
                    nlvr_results[1] += sum((label == 1) & (predict == 0))
                    nlvr_results[2] += sum((label == 0) & (predict == 1))
                    nlvr_results[3] += sum((label == 0) & (predict == 0))
                    n_total = sum(nlvr_results)

                    desc_str += f' TP {nlvr_results[0]} ({nlvr_results[0]/n_total*100:.1f}%)'
                    desc_str += f' FN {nlvr_results[1]} ({nlvr_results[1]/n_total*100:.1f}%)'
                    desc_str += f' FP {nlvr_results[2]} ({nlvr_results[2]/n_total*100:.1f}%)'
                    desc_str += f' TN {nlvr_results[3]} ({nlvr_results[3]/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()

                log_str = ''

                train_acc = self.train_loader.evaluator.evaluate_train(quesid2ans) * 100.

                train_score = train_acc

                log_str += "\nEpoch %d: Train %0.2f" % (epoch, train_score)

                # Validation; keep the best checkpoint by accuracy.
                valid_score_dict = self.evaluate(self.val_loader)
                valid_acc = valid_score_dict['accuracy'] * 100.

                valid_score = valid_acc

                if valid_score > best_valid:
                    best_valid = valid_score
                    best_epoch = epoch
                    self.save("BEST")

                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (best_epoch, best_valid)

                wandb_log_dict = {}
                wandb_log_dict['Train/accuracy'] = train_acc

                for score_name, score in valid_score_dict.items():
                    wandb_log_dict[f'Valid/{score_name}'] = score * 100.

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

            # Test set: reload the best checkpoint and dump predictions.
            best_path = os.path.join(self.args.output, 'BEST')
            self.load(best_path)

            log_str = 'Test set results\n'

            dump_path = os.path.join(self.args.output, 'submit.csv')
            test_score_dict = self.evaluate(self.test_loader, dump_path=dump_path)
            wandb.save(dump_path, base_path=self.args.output)

            wandb_log_dict = {}
            for score_name, score in test_score_dict.items():
                wandb_log_dict[f'Test/{score_name}'] = score * 100.
            # NOTE(review): `epoch` is the index of the last training epoch,
            # which was already used as a wandb step above — wandb may warn
            # about re-logging at the same step.
            wandb.log(wandb_log_dict, step=epoch)

            from pprint import pformat

            log_str += pformat(test_score_dict)

            print(log_str)

            wandb.log({'finished': True})
Ejemplo n.º 13
0
    def train(self):
        """Run the multitask pre-finetuning loop over a mixed-task loader.

        Each batch carries a ``batch['task']`` tag (one of vqa/gqa/nlvr/
        refcoco/vcr/caption/mmt) and is trained with the shared
        ``train_step``.  After every epoch the verbose (rank-0) process
        validates each task enabled in ``self.args.tasks``, tracks the best
        score per task, and logs to wandb; after the final epoch it saves
        "LAST" and evaluates the test loaders.  Non-verbose ranks only
        train and synchronise at ``dist.barrier()`` points; the whole job
        ends with ``exit()`` on distributed runs.
        """
        if self.verbose:
            vqa_loss_meter = LossMeter()
            refcoco_loss_meter = LossMeter()
            # best_eval_loss = 9595.
            quesid2ans = {}
            # Per-task best-score bookkeeping (score and epoch it occurred).
            best_vqa_valid = 0.
            best_vqa_epoch = 0

            # gqa
            best_gqa_valid = 0
            best_gqa_epoch = 0

            # nlvr
            best_nlvr_valid = 0
            best_nlvr_epoch = 0

            # vcr
            best_valid_Q_AR = 0
            best_vcr_epoch = 0

            # refcoco
            best_refcoco_valid = 0
            best_refcoco_epoch = 0

            # caption
            best_caption_valid = 0
            best_caption_epoch = 0

            # mmt
            best_mmt_valid = 0
            best_mmt_epoch = 0

            # NOTE(review): assert is stripped under `python -O`; a raise
            # would be a harder guarantee that only T5 backbones get here.
            assert 't5' in self.args.backbone
            self.setup_wandb()

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # When resuming, shift the epoch index to continue numbering.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()

            if self.args.distributed:
                # NOTE: the multitask loader exposes set_epoch directly
                # (no .sampler), unlike the single-task trainers.
                self.train_loader.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=250)

            epoch_results = {
                'loss': 0.,
            }

            # How many batches of each task were seen this epoch (shown in
            # the progress bar).
            task_counter = {
                'vqa': 0,
                'gqa': 0,
                'nlvr': 0,
                'refcoco': 0,
                'vcr': 0,
                'caption': 0,
                'mmt': 0,
            }

            # vqa
            quesid2ans = {}
            train_acc = 0.
            # train_acc_steps = int(len(self.train_loader) * 0.05)
            # last_acc_step = 0

            # refcoco: running train-accuracy counters (all ranks).
            n_correct = 0
            n_total = 0

            for step_i, batch in enumerate(self.train_loader):

                # print(f'GPU{self.args.gpu} inside training loop')
                # print(batch)
                task = batch['task']
                # if self.verbose:
                #     print('task', task)
                task_counter[task] += 1

                batch['log_train_accuracy'] = self.args.log_train_accuracy

                # self.optim.zero_grad()
                # Forward pass (autocast only when native amp is in use).
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # print(f'GPU{self.args.gpu} after loss')

                # Backward: scaled for native amp / apex, plain otherwise.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # print(f'GPU{self.args.gpu} after backward')

                loss = loss.detach()

                # Update Parameters
                # Gradient clipping (native amp requires unscaling first).
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                # Optimizer step (through the scaler under native amp).
                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Drop gradients by rebinding to None (cheaper than the
                # memset of zero_grad()).
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Current learning rate for the progress bar;
                # get_last_lr() only exists from torch 1.4 onwards.
                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                # self.train_step_post_hook(result)

                # RefCOCO train accuracy is accumulated on every rank so it
                # can be all-reduced after the epoch.
                if self.args.log_train_accuracy and task == 'refcoco':
                    correct = results['correct']
                    n_correct += sum(correct)
                    n_total += len(correct)

                if self.verbose:
                    if task == 'vqa':
                        vqa_loss_meter.update(loss.item())
                    elif task == 'refcoco':
                        refcoco_loss_meter.update(loss.item())

                    desc_str = f'Epoch {epoch} | LR {lr:.6f}'

                    desc_str += f" |"
                    if 'vqa' in self.args.tasks:
                        desc_str += f" VQA {task_counter['vqa']}"
                    if 'gqa' in self.args.tasks:
                        desc_str += f" GQA {task_counter['gqa']}"
                    if 'nlvr' in self.args.tasks:
                        desc_str += f" NLVR {task_counter['nlvr']}"
                    if 'vcr' in self.args.tasks:
                        desc_str += f" VCR {task_counter['vcr']}"
                    if 'refcoco' in self.args.tasks:
                        desc_str += f" RefCOCOg {task_counter['refcoco']}"
                    if 'caption' in self.args.tasks:
                        desc_str += f" COCO {task_counter['caption']}"
                    if 'mmt' in self.args.tasks:
                        desc_str += f" MMT {task_counter['mmt']}"

                    if len(vqa_loss_meter) > 0:
                        desc_str += f' | VQA Loss {vqa_loss_meter.val:4f}'
                    if len(refcoco_loss_meter) > 0:
                        desc_str += f' | RefCOCOg Loss {refcoco_loss_meter.val:.3f}'

                    if self.args.log_train_accuracy and n_total > 0:
                        desc_str += f' | RefCOCOg Acc'
                        desc_str += f' Correct {n_correct:.0f}'
                        desc_str += f' (Acc {n_correct/n_total*100:.1f}%)'

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()
                # Checkpoint every epoch (1-based name).
                self.save("Epoch%02d" % (epoch + 1))

            # Aggregate RefCOCO train counters across ranks; presumably
            # reduce_dict sums them onto each GPU — verify against its
            # definition.
            if self.args.log_train_accuracy:
                train_score_dict = {'n_correct': n_correct, 'n_total': n_total}
                train_score_dict = reduce_dict(train_score_dict, self.args.gpu)

            if self.verbose:
                # Validation
                log_str = ''
                wandb_log_dict = {}

                if 'vqa' in self.args.tasks:
                    # VQA
                    vqa_val_loader = self.val_loader['vqa']
                    score_dict = self.vqa_evaluate(vqa_val_loader)
                    valid_score = score_dict['topk_score'] * 100.
                    valid_score_raw = score_dict['overall']
                    # `or epoch == 0` forces the baseline to be recorded on
                    # the first epoch (same pattern for each task below).
                    if valid_score_raw > best_vqa_valid or epoch == 0:
                        best_vqa_valid = valid_score_raw
                        best_vqa_epoch = epoch
                        # self.save("VQA_BEST")
                    log_str += f"VQA"
                    log_str += "\nEpoch %d: Valid Raw %0.2f Topk %0.2f" % (
                        epoch, valid_score_raw, valid_score)
                    log_str += "\nEpoch %d: Best Raw %0.2f\n" % (
                        best_vqa_epoch, best_vqa_valid)
                    wandb_log_dict['VQA/Valid/score'] = valid_score
                    wandb_log_dict['VQA/Valid/raw_score'] = score_dict[
                        'overall']
                if 'gqa' in self.args.tasks:
                    # GQA
                    gqa_val_loader = self.val_loader['gqa']
                    valid_score = self.gqa_evaluate(gqa_val_loader) * 100
                    if valid_score > best_gqa_valid or epoch == 0:
                        best_gqa_valid = valid_score
                        best_gqa_epoch = epoch
                    wandb_log_dict['GQA/Valid/Acc'] = valid_score
                    log_str += f"GQA"
                    log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_gqa_epoch,
                                                             best_gqa_valid)
                if 'nlvr' in self.args.tasks:
                    # NLVR
                    nlvr_val_loader = self.val_loader['nlvr']
                    valid_score_dict = self.nlvr_evaluate(nlvr_val_loader)
                    valid_acc = valid_score_dict['accuracy'] * 100.
                    if valid_acc > best_nlvr_valid or epoch == 0:
                        best_nlvr_valid = valid_acc
                        best_nlvr_epoch = epoch
                    wandb_log_dict['NLVR/Valid/Acc'] = valid_acc
                    log_str += f"NLVR"
                    log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_acc)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_nlvr_epoch,
                                                             best_nlvr_valid)

                if 'vcr' in self.args.tasks:
                    # VCR: Q->A, QA->R, and joint Q->AR accuracies, each
                    # normalised by the same n_total.
                    vcr_val_loader = self.val_loader['vcr']
                    valid_score_dict = self.vcr_evaluate(vcr_val_loader)
                    valid_Q_A = valid_score_dict['Q_A'] / valid_score_dict[
                        'n_total'] * 100
                    valid_QA_R = valid_score_dict['QA_R'] / valid_score_dict[
                        'n_total'] * 100
                    valid_Q_AR = valid_score_dict['Q_AR'] / valid_score_dict[
                        'n_total'] * 100
                    valid_n_total = int(valid_score_dict['n_total'])
                    if valid_Q_AR > best_valid_Q_AR or epoch == 0:
                        best_valid_Q_AR = valid_Q_AR
                        best_vcr_epoch = epoch
                    wandb_log_dict['VCR/Valid/Q_A'] = valid_Q_A
                    wandb_log_dict['VCR/Valid/QA_R'] = valid_QA_R
                    wandb_log_dict['VCR/Valid/Q_AR'] = valid_Q_AR
                    log_str += f"VCR"
                    log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_Q_AR)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_vcr_epoch,
                                                             best_valid_Q_AR)

                if 'refcoco' in self.args.tasks:
                    # RefCOCO
                    refcoco_val_loader = self.val_loader['refcoco']
                    if self.args.log_train_accuracy:
                        train_acc = train_score_dict[
                            'n_correct'] / train_score_dict['n_total'] * 100
                        train_n_correct = int(train_score_dict['n_correct'])
                        train_n_total = int(train_score_dict['n_total'])
                    valid_score_dict = self.refcoco_evaluate(
                        refcoco_val_loader)
                    valid_acc = valid_score_dict[
                        'n_correct'] / valid_score_dict['n_total'] * 100
                    valid_n_correct = int(valid_score_dict['n_correct'])
                    valid_n_total = int(valid_score_dict['n_total'])
                    if valid_acc > best_refcoco_valid or epoch == 0:
                        best_refcoco_valid = valid_acc
                        best_refcoco_epoch = epoch
                    if self.args.log_train_accuracy:
                        wandb_log_dict['RefCOCO/Train/Acc'] = train_acc
                    wandb_log_dict['RefCOCO/Valid/Acc'] = valid_acc
                    log_str += f"RefCOCOg"
                    if self.args.log_train_accuracy:
                        log_str += f"\nEpoch {epoch}: Train"
                        log_str += f" Acc {train_acc:.2f}% |"
                        log_str += f" # correct {train_n_correct} # total {train_n_total}"
                    log_str += f"\nEpoch {epoch}: Valid"
                    log_str += f" Acc {valid_acc:.2f}% |"
                    log_str += f" # correct {valid_n_correct} # total {valid_n_total}"
                    log_str += f"\nEpoch {best_refcoco_epoch}: Best Acc {best_refcoco_valid:.2f}%\n"

                if 'caption' in self.args.tasks:
                    # COCO Caption: model selection by CIDEr.
                    caption_val_loader = self.val_loader['caption']
                    valid_results = self.caption_evaluate(caption_val_loader)
                    valid_score = valid_results['CIDEr'] * 100
                    if valid_score > best_caption_valid or epoch == 0:
                        best_caption_valid = valid_score
                        best_caption_epoch = epoch
                    for score_name, score in valid_results.items():
                        wandb_log_dict[
                            f'Caption/Valid/{score_name}'] = score * 100
                    log_str += f"COCO Caption"
                    log_str += "\nEpoch %d: Valid CIDEr %0.2f" % (epoch,
                                                                  valid_score)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (
                        best_caption_epoch, best_caption_valid)

                if 'mmt' in self.args.tasks:
                    # MMT: model selection by BLEU (no epoch-0 forcing here,
                    # unlike the other tasks).
                    mmt_val_loader = self.val_loader['mmt']
                    valid_results = self.mmt_evaluate(mmt_val_loader)
                    valid_score = valid_results['BLEU']
                    if valid_score > best_mmt_valid:
                        best_mmt_valid = valid_score
                        best_mmt_epoch = epoch
                    for score_name, score in valid_results.items():
                        wandb_log_dict[f'MMT/Valid/{score_name}'] = score
                    log_str += f"Multi30K En-De"
                    log_str += "\nEpoch %d: Valid BLEU %0.2f" % (epoch,
                                                                 valid_score)
                    log_str += "\nEpoch %d: Best %0.2f\n" % (best_mmt_epoch,
                                                             best_mmt_valid)

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        # Test Set
        if self.verbose:
            self.save("LAST")

            log_str = ''
            wandb_log_dict = {}

            if 'vqa' in self.args.tasks:
                # VQA: Karpathy test split plus the server-submission split,
                # with prediction dumps saved to wandb.
                vqa_test_loader = self.test_loader['vqa']
                evaluator = vqa_test_loader.evaluator
                dump_path = os.path.join(self.args.output,
                                         'karpathy_test_predict.json')
                quesid2ans = self.vqa_predict(vqa_test_loader, dump_path)
                wandb.save(dump_path, base_path=self.args.output)

                acc_dict_all = evaluator.evaluate_raw(quesid2ans)
                acc_dict_answerable = evaluator.evaluate_raw(
                    quesid2ans, is_topk_optimal=True)
                acc_dict_unanswerable = evaluator.evaluate_raw(
                    quesid2ans, is_topk_optimal=False)

                wandb_log_dict['VQA/Test/overall'] = acc_dict_all['overall']
                wandb_log_dict['VQA/Test/topk_optimal'] = acc_dict_answerable[
                    'overall']
                wandb_log_dict[
                    'VQA/Test/topk_not_optimal'] = acc_dict_unanswerable[
                        'overall']

                vqa_submit_test_loader = self.test_loader['vqa_submit']
                dump_path = os.path.join(self.args.output, 'vqa_submit.json')
                self.vqa_predict(vqa_submit_test_loader, dump_path=dump_path)
                wandb.save(dump_path, base_path=self.args.output)

            if 'nlvr' in self.args.tasks:
                # NLVR
                nlvr_test_loader = self.test_loader['nlvr']
                dump_path = os.path.join(self.args.output, 'nlvr_submit.csv')
                test_score_dict = self.nlvr_evaluate(nlvr_test_loader,
                                                     dump_path=dump_path)
                wandb.save(dump_path, base_path=self.args.output)
                for score_name, score in test_score_dict.items():
                    wandb_log_dict[f'NLVR/Test/{score_name}'] = score * 100.
            if 'refcoco' in self.args.tasks:
                # RefCOCO
                refcoco_test_loader = self.test_loader['refcoco']
                test_score_dict = self.refcoco_evaluate(refcoco_test_loader)
                test_acc = test_score_dict['n_correct'] / test_score_dict[
                    'n_total'] * 100
                test_n_correct = int(test_score_dict['n_correct'])
                test_n_total = int(test_score_dict['n_total'])
                wandb_log_dict['RefCOCO/test/Acc'] = test_acc
                # NOTE(review): this rebinds log_str with '=', discarding
                # anything accumulated above for this summary.
                log_str = 'RefCOCOg'
                log_str += f"\nTest Acc {test_acc:.2f}%"
                log_str += f"\nTest # correct {test_n_correct} # total {test_n_total}"
            if 'caption' in self.args.tasks:
                # COCO Caption
                caption_test_loader = self.test_loader['caption']
                test_results = self.caption_evaluate(caption_test_loader)
                for score_name, score in test_results.items():
                    wandb_log_dict[f'Caption/Test/{score_name}'] = score

            if 'mmt' in self.args.tasks:
                # MMT: evaluate every test-year split and dump raw outputs.
                mmt_test2016_loader = self.test_loader['mmt_test2016']
                mmt_test2017_loader = self.test_loader['mmt_test2017']
                mmt_test2018_loader = self.test_loader['mmt_test2018']
                for loader in [
                        mmt_test2016_loader, mmt_test2017_loader,
                        mmt_test2018_loader
                ]:
                    split = loader.dataset.source
                    dump_path = os.path.join(self.args.output,
                                             f'submit_{split}_raw.txt')
                    test_results = self.mmt_evaluate(loader,
                                                     dump_path=dump_path)
                    for score_name, score in test_results.items():
                        wandb_log_dict[f'MMT/{split}/{score_name}'] = score
                    log_str += f'{split} set results\n'
                    log_str += pformat(test_results)

            print(log_str)
            wandb.log(wandb_log_dict, step=self.args.epochs)

            wandb.log({'finished': True})

        if self.args.distributed:
            dist.barrier()
            # Terminate all ranks once rank-0 has finished test evaluation.
            exit()
Ejemplo n.º 14
0
    def train(self):
        """Fine-tune the captioning model and evaluate it.

        Per epoch: one training pass over ``self.train_loader`` (supporting
        native-amp / apex fp16, gradient clipping, and gradient
        accumulation), then validation scored by CIDEr.  The best-scoring
        checkpoint is saved as "BEST" and, after the final epoch, reloaded
        and evaluated on the test set.  Only the verbose (rank-0) process
        touches wandb / tqdm / checkpoints; in distributed mode all ranks
        synchronize at ``dist.barrier()`` points.
        """
        if self.verbose:
            loss_meter = LossMeter()
            best_valid = 0.
            best_epoch = 0

            # One-time wandb setup (guarded so a repeated train() call does
            # not re-initialize the run).
            if not self.wandb_initialized:

                if 't5' in self.args.backbone:
                    project_name = "VLT5_COCOCaption"
                elif 'bart' in self.args.backbone:
                    project_name = "VLBart_COCOCaption"
                else:
                    # Bug fix: project_name was previously unbound here for
                    # any backbone other than t5/bart, raising NameError at
                    # wandb.init below.
                    project_name = "VL_COCOCaption"

                wandb.init(project=project_name)
                wandb.run.name = self.args.run_name
                wandb.config.update(self.args)
                wandb.watch(self.model)

                # Snapshot the source tree alongside the run for
                # reproducibility.
                src_dir = Path(__file__).resolve().parent
                base_path = str(src_dir.parent)
                src_dir = str(src_dir)
                wandb.save(os.path.join(src_dir + "/*.py"),
                           base_path=base_path)

                self.wandb_initialized = True

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        epochs = self.args.epochs

        for epoch in range(epochs):

            # Resume support: shift the epoch index when restarting from a
            # checkpoint.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                # Reshuffle the distributed sampler for this epoch.
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=120)

            epoch_results = {
                'loss': 0.,
            }

            for step_i, batch in enumerate(self.train_loader):

                # Forward pass (the model is wrapped in DDP when
                # distributed, hence the .module indirection).
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Backward pass under the active fp16 scheme.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Gradient clipping (gradients must be unscaled first under
                # native amp).
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                # Gradient accumulation: step the optimizer only every
                # `gradient_accumulation_steps` batches (and always on the
                # last batch of the epoch).
                update = True
                if self.args.gradient_accumulation_steps > 1:
                    if step_i == 0:
                        update = False
                    elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(
                            self.train_loader) - 1:
                        update = True
                    else:
                        update = False

                if update:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optim)
                        self.scaler.update()
                    else:
                        self.optim.step()

                    if self.lr_scheduler:
                        self.lr_scheduler.step()
                    # Zero gradients by dropping them — cheaper than
                    # model.zero_grad(), which writes zero tensors.
                    for param in self.model.parameters():
                        param.grad = None
                    global_step += 1

                # Accumulate per-epoch loss totals.
                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # Current LR for display: the scheduler API changed in
                # torch 1.4 (get_last_lr); otherwise fall back to the
                # optimizer or the configured constant LR.
                if self.lr_scheduler:
                    if version.parse(
                            torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}'
                    desc_str += f' | Loss {loss_meter.val:4f}'
                    pbar.set_description(desc_str)
                    pbar.update(1)

            if self.args.distributed:
                dist.barrier()

            if self.verbose:
                pbar.close()

            # Validation: all ranks run it; only rank 0 logs and saves.
            # evaluate() returns a dict of caption metrics, e.g.
            # {'Bleu_1': ..., 'CIDEr': ..., 'METEOR': ..., 'ROUGE_L': ...}.
            valid_results = self.evaluate(self.val_loader)

            if self.verbose:
                valid_score = valid_results['CIDEr']

                # Checkpoint on improvement; epoch 0 always saves so a
                # "BEST" checkpoint exists for the final test reload.
                if valid_score > best_valid or epoch == 0:
                    best_valid = valid_score
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''

                log_str += pformat(valid_results)
                log_str += "\nEpoch %d: Valid CIDEr %0.4f" % (epoch,
                                                              valid_score)
                log_str += "\nEpoch %d: Best CIDEr %0.4f\n" % (best_epoch,
                                                               best_valid)

                wandb_log_dict = {}
                wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(
                    self.train_loader)

                for score_name, score in valid_results.items():
                    wandb_log_dict[f'Valid/{score_name}'] = score

                wandb_log_dict['Valid/best_epoch'] = best_epoch

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

        # Test set: every rank reloads the best checkpoint.
        best_path = os.path.join(self.args.output, 'BEST')
        self.load(best_path)

        if self.verbose:
            wandb.save(best_path, base_path=self.args.output)
            print(f'\nUploaded checkpoint {best_epoch}', best_path)

        test_results = self.evaluate(self.test_loader)

        if self.verbose:
            wandb_log_dict = {}
            for score_name, score in test_results.items():
                wandb_log_dict[f'Test/{score_name}'] = score
            wandb.log(wandb_log_dict, step=epoch)

            log_str = 'Test set results\n'
            log_str += pformat(test_results)

            print(log_str)

        if self.args.distributed:
            dist.barrier()
Ejemplo n.º 15
0
    def train(self):
        """Pre-training loop cycling over masked-LM / visual-mask /
        matched / QA tasks, with per-rank loss accounting reduced across
        GPUs, TensorBoard logging on rank 0, validation each epoch, and a
        checkpoint saved every epoch.
        """
        LOSSES_NAME = self.args.LOSSES_NAME
        # Maps each displayed loss name to the batch-level task that
        # produces it (used below to decide which meters to update).
        task_dict = {
            'Mask_LM': 'word_mask',
            'Matched': 'matched',
            'Mask_Obj': 'vis_mask',
            'Mask_Attr': 'vis_mask',
            'Mask_Feat': 'vis_mask',
            'QA': 'qa'
        }

        # Dry run: one validation pass before any training.
        if self.args.dry:
            results = self.evaluate_epoch(epoch=0)

        self.optim.zero_grad()

        if self.verbose:
            loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]
            best_eval_loss = 9595.  # sentinel; any real loss is smaller

            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.args.log_dir)
            print('logging at', str(self.args.log_dir))
            self.logger.info('logging at' + str(self.args.log_dir))

            # Record scalar-valued hyperparameters in TensorBoard.
            hparam_dict = {}
            for k, v in self.args.__dict__.items():
                if type(v) in [int, float, str, bool, torch.Tensor]:
                    hparam_dict[k] = v
            metric_dict = {}

            self.writer.add_hparams(hparam_dict, metric_dict)

        dist.barrier()

        n_update = 0
        global_step = 0
        for epoch in range(self.args.epochs):
            # Resume support: shift the epoch index when restarting.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)

            # Train
            self.model.train()
            loss_counts = [0 for _ in range(len(LOSSES_NAME))]

            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=240)

            # Per-epoch loss sums; a parallel '<name>_count' entry is added
            # for each so averages can be computed after reduce_dict.
            epoch_results = {
                'lm_loss': 0,
                'vis_loss': 0,
                'matched_loss': 0,
                'qa_loss': 0,
                'obj_loss': 0,
                'feat_loss': 0,
                'attr_loss': 0,
            }
            for k in list(epoch_results.keys()):
                if k[-4:] == 'loss':
                    epoch_results[f'{k}_count'] = 0

            if self.args.task_qa:
                uid2ans = {}

            for step_i, batch in enumerate(self.train_loader):
                # Tasks are cycled round-robin over MASK_MODALITY rather
                # than sampled randomly.
                # task = random.choice(self.args.MASK_MODALITY)
                task_i = step_i % len(self.args.MASK_MODALITY)
                task = self.args.MASK_MODALITY[task_i]

                # with torch.autograd.set_detect_anomaly(True):
                # NOTE(review): this result is immediately overwritten by
                # the self.model(batch, task) call below, so this forward
                # pass appears redundant (doubled compute per step) —
                # confirm self.forward has no required side effects before
                # removing it.
                results = self.forward(batch, task)

                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        results = self.model(batch, task)
                else:
                    results = self.model(batch, task)

                # Bump the count for every loss this task contributes to.
                if task == 'vis_mask':
                    if 'Mask_Obj' in LOSSES_NAME:
                        epoch_results['obj_loss_count'] += 1
                    if 'Mask_Feat' in LOSSES_NAME:
                        epoch_results['feat_loss_count'] += 1
                    if 'Mask_Attr' in LOSSES_NAME:
                        epoch_results['attr_loss_count'] += 1
                    epoch_results['vis_loss_count'] += 1
                elif task == 'word_mask':
                    epoch_results['lm_loss_count'] += 1
                elif task == 'matched':
                    epoch_results['matched_loss_count'] += 1

                # QA head runs on every batch when task_qa is enabled;
                # collect answer predictions for accuracy evaluation below.
                if self.args.task_qa:
                    epoch_results['qa_loss_count'] += 1
                    qa_pred = results['qa_pred']
                    for uid, ans_id in zip(batch['uid'],
                                           qa_pred.cpu().numpy()):
                        ans = self.train_loader.dataset.answer_table.id2ans(
                            ans_id)
                        uid2ans[uid] = ans

                loss = results['total_loss']

                #===== Update =====#
                # Backward under the active fp16 scheme.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Update Parameters
                # Gradient clipping (unscale first under native amp).
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optim),
                            self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()
                # Zero gradients by dropping them (cheaper than zero_grad).
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1
                #====================#

                # NOTE(review): this reads self.scheduler, but the rest of
                # the method uses self.lr_scheduler; if self.scheduler does
                # not exist, the displayed LR silently falls back to the
                # configured constant — verify which attribute is intended.
                try:
                    lr = self.scheduler.get_last_lr()[0]
                except AttributeError:
                    lr = self.args.lr

                if self.verbose:
                    # Progress-bar description: LR, masking configuration,
                    # and a running value for each active loss.
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    if self.args.word_mask_predict:
                        desc_str += f'Word Mask: Uniform (MP) | '
                    elif self.args.word_mask_rate > 0:
                        desc_str += f'Word Mask: {self.args.word_mask_rate:.2f} | '

                    if self.args.vis_mask_predict:
                        desc_str += f'Vis Mask: Uniform (MP) |'
                    else:
                        desc_str += f'Vis Mask: {self.args.obj_mask_rate:.2f} |'

                    # QA loss meter is conventionally the last in the list.
                    if self.args.task_qa:
                        loss_meter = loss_meters[-1]
                        loss_meter.update(results['qa_loss'].item())
                        loss_counts[-1] += 1

                    for i, (loss_name, loss_meter) in enumerate(
                            zip(LOSSES_NAME, loss_meters)):
                        # Only update meters whose loss was produced by the
                        # task run on this step.
                        if task_dict[loss_name] == task:
                            if task == 'vis_mask':
                                if loss_name == 'Mask_Obj':
                                    loss_meter.update(
                                        results['obj_loss'].item())
                                elif loss_name == 'Mask_Attr':
                                    loss_meter.update(
                                        results['attr_loss'].item())
                                elif loss_name == 'Mask_Feat':
                                    loss_meter.update(
                                        results['feat_loss'].item())
                            elif task == 'word_mask':
                                loss_meter.update(results['lm_loss'].item())
                            elif task == 'matched':
                                loss_meter.update(
                                    results['matched_loss'].item())
                            # elif task == 'qa':
                            #     loss_meter.update(results['qa_loss'].item())

                            loss_counts[i] += 1
                        if len(loss_meter) > 0:
                            loss_count = loss_counts[i]
                            if loss_name in [
                                    'Mask_LM', 'Matched', 'Mask_Obj',
                                    'Mask_Attr', 'Mask_Feat', 'QA'
                            ]:
                                desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'
                            else:
                                desc_str += f' {loss_name} {loss_meter.val:.3f}'

                            # Log to TensorBoard every 10 steps only.
                            if step_i % 10 == 0:
                                self.writer.add_scalar(
                                    f'Train_steps/{loss_name}', loss_meter.val,
                                    global_step)

                    # if update:
                    n_update += 1
                    desc_str += f' | Total Update: {n_update}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

            if self.verbose:
                pbar.close()

            dist.barrier()

            # Sum the per-rank loss totals across GPUs; only rank 0 logs.
            results = reduce_dict(epoch_results, self.args.gpu)
            if self.args.gpu == 0:
                total_loss = results['lm_loss'] + results[
                    'vis_loss'] + results['matched_loss'] + results['qa_loss']
                # NOTE(review): qa_loss is included in the numerator but
                # qa_loss_count is excluded from the denominator (see the
                # commented line) — confirm this average is intentional.
                total_count = results['lm_loss_count'] + results[
                    'vis_loss_count'] + results['matched_loss_count']
                # + results['qa_loss_count']

                avg_train_loss = total_loss / total_count
                losses_str = f"Train Loss: {avg_train_loss:.4f}\n"

                # Report each loss average under its display name.
                for name, loss in results.items():
                    if name[-4:] == 'loss':
                        loss_count = int(results[name + '_count'])
                        if loss_count > 0:
                            avg_loss = loss / loss_count
                            if name == 'lm_loss':
                                name = 'Mask_LM'
                            elif name == 'matched_loss':
                                name = 'Matched'
                            elif name == 'obj_loss':
                                name = 'Mask_Obj'
                            elif name == 'attr_loss':
                                name = 'Mask_Attr'
                            elif name == 'feat_loss':
                                name = 'Mask_Feat'
                            elif name == 'qa_loss':
                                name = 'QA'
                            losses_str += f"{name} ({loss_count}): {avg_loss:.4f} "
                            self.writer.add_scalar(f'Train Loss/{name}',
                                                   avg_loss, epoch)
                losses_str += '\n'
                print(losses_str)
                self.logger.info(losses_str)

            # Train-set QA accuracy, reduced across ranks.
            if self.args.task_qa:
                dset2score, dset2cnt, score, cnt = self.train_loader.dataset.evaluator.evaluate(
                    uid2ans)

                dset2score = reduce_dict(dset2score, self.args.gpu)
                dset2cnt = reduce_dict(dset2cnt, self.args.gpu)
                score_cnt_dict = reduce_dict({
                    'score': score,
                    'cnt': cnt
                }, self.args.gpu)

                if self.args.gpu == 0:
                    score = score_cnt_dict['score']
                    cnt = score_cnt_dict['cnt']
                    accu = score / cnt
                    dset2accu = {}
                    for dset in dset2cnt:
                        dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
                    accu_str = "Overall Accu %0.4f, " % (accu)
                    sorted_keys = sorted(dset2accu.keys())
                    for key in sorted_keys:
                        accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
                    print(accu_str)
                    self.logger.info(accu_str)

            dist.barrier()

            # Validation
            valid_results, valid_uid2ans = self.evaluate_epoch(epoch=epoch)

            valid_results = reduce_dict(valid_results, self.args.gpu)
            if self.args.gpu == 0:
                valid_total_loss = valid_results['lm_loss'] + valid_results[
                    'vis_loss'] + valid_results[
                        'matched_loss'] + valid_results['qa_loss']
                # Same numerator/denominator asymmetry as the train average
                # above (qa_loss_count deliberately commented out).
                valid_total_count = valid_results[
                    'lm_loss_count'] + valid_results[
                        'vis_loss_count'] + valid_results['matched_loss_count']
                #  + valid_results['qa_loss_count']

                avg_valid_loss = valid_total_loss / valid_total_count
                losses_str = f"Valid Loss: {avg_valid_loss:.4f}\n"

                for name, loss in valid_results.items():
                    if name[-4:] == 'loss':
                        loss_count = int(valid_results[name + '_count'])
                        if loss_count > 0:
                            avg_loss = loss / loss_count
                            if name == 'lm_loss':
                                name = 'Mask_LM'
                            elif name == 'matched_loss':
                                name = 'Matched'
                            elif name == 'obj_loss':
                                name = 'Mask_Obj'
                            elif name == 'attr_loss':
                                name = 'Mask_Attr'
                            elif name == 'feat_loss':
                                name = 'Mask_Feat'
                            elif name == 'qa_loss':
                                name = 'QA'
                            losses_str += f"{name} ({loss_count}): {avg_loss:.4f} "
                            self.writer.add_scalar(f'Valid Loss/{name}',
                                                   avg_loss, epoch)

                losses_str += '\n'
                print(losses_str)
                self.logger.info(losses_str)

            # Validation-set QA accuracy, reduced across ranks.
            if self.args.task_qa:
                dset2score, dset2cnt, score, cnt = self.val_loader.dataset.evaluator.evaluate(
                    valid_uid2ans)

                dset2score = reduce_dict(dset2score, self.args.gpu)
                dset2cnt = reduce_dict(dset2cnt, self.args.gpu)
                score_cnt_dict = reduce_dict({
                    'score': score,
                    'cnt': cnt
                }, self.args.gpu)

                if self.args.gpu == 0:
                    score = score_cnt_dict['score']
                    cnt = score_cnt_dict['cnt']
                    accu = score / cnt
                    dset2accu = {}
                    for dset in dset2cnt:
                        dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
                    accu_str = "Overall Accu %0.4f, " % (accu)
                    sorted_keys = sorted(dset2accu.keys())
                    for key in sorted_keys:
                        accu_str += "%s Accu %0.4f, " % (key, dset2accu[key])
                    print(accu_str)
                    self.logger.info(accu_str)

            dist.barrier()

            if self.verbose:
                # Save
                # best_eval_loss is tracked but every epoch is checkpointed
                # regardless (BEST_EVAL_LOSS saving is commented out).
                if avg_valid_loss < best_eval_loss:
                    best_eval_loss = avg_valid_loss
                #     self.save("BEST_EVAL_LOSS")
                self.save("Epoch%02d" % (epoch + 1))

            dist.barrier()
Ejemplo n.º 16
0
    def train(self):
        """Pre-train the model.

        Per epoch: one training pass over ``self.train_loader`` (native-amp
        / apex fp16, gradient clipping), reduce the per-loss sums across
        ranks, run validation via ``evaluate_epoch``, optionally score QA
        accuracy, and save a checkpoint.  Only the verbose (rank-0) process
        logs to wandb / tqdm and writes checkpoints.
        """
        LOSSES_NAME = self.args.LOSSES_NAME

        # Dry run: a single validation pass before any training.
        if self.args.dry:
            results = self.evaluate_epoch(epoch=0)

        if self.verbose:
            loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]
            best_eval_loss = 9595.  # sentinel; any real loss is smaller

            if 't5' in self.args.backbone:
                project_name = "VLT5_Pretrain"
            elif 'bart' in self.args.backbone:
                project_name = "VLBart_Pretrain"
            else:
                # Bug fix: project_name was previously unbound here for any
                # backbone other than t5/bart, raising NameError at
                # wandb.init below.
                project_name = "VL_Pretrain"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot the source tree alongside the run.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            # Resume support: shift the epoch index when restarting.
            if self.start_epoch is not None:
                epoch += self.start_epoch
            if self.args.distributed:
                # Reshuffle the distributed sampler for this epoch.
                self.train_loader.sampler.set_epoch(epoch)

            # Train
            self.model.train()

            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=250)

            # Per-loss running sums plus parallel '<name>_count' entries,
            # reduced across ranks after the epoch.
            epoch_results = {}
            for loss_name in LOSSES_NAME:
                epoch_results[loss_name] = 0.
                epoch_results[f'{loss_name}_count'] = 0

            for step_i, batch in enumerate(self.train_loader):

                # Forward pass (model is DDP-wrapped when distributed).
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # Backward pass under the active fp16 scheme.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss = loss.detach()

                # Gradient clipping (unscale first under native amp).
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optim), self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm)

                # Optimizer step (no gradient accumulation in this loop).
                if self.args.fp16 and _use_native_amp:
                    self.scaler.step(self.optim)
                    self.scaler.update()
                else:
                    self.optim.step()

                if self.lr_scheduler:
                    self.lr_scheduler.step()

                # Zero gradients by dropping them (cheaper than zero_grad).
                for param in self.model.parameters():
                    param.grad = None

                global_step += 1

                # Current LR for display: the scheduler API changed in
                # torch 1.4 (get_last_lr); otherwise fall back to the
                # optimizer or the configured constant LR.
                if self.lr_scheduler:
                    if version.parse(torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                # Accumulate whatever scalar results this step reported.
                for k, v in results.items():
                    if k in epoch_results:
                        if isinstance(v, int):
                            epoch_results[k] += v
                        elif isinstance(v, torch.Tensor):
                            epoch_results[k] += v.item()

                if self.verbose:
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} |'

                    for i, (loss_name, loss_meter) in enumerate(zip(LOSSES_NAME, loss_meters)):

                        # Each reported loss comes with a matching
                        # '<loss_name>_count' in results for averaging.
                        if loss_name in results:
                            loss_meter.update(results[f'{loss_name}'] / results[f'{loss_name}_count'])
                        if len(loss_meter) > 0:
                            loss_count = epoch_results[f'{loss_name}_count']
                            desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.3f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)

            if self.verbose:
                pbar.close()

            dist.barrier()

            # Sum the per-rank loss totals across GPUs; only rank 0 logs.
            results = reduce_dict(epoch_results, average=False)
            if self.verbose:
                train_loss = results['total_loss']
                train_loss_count = results['total_loss_count']

                avg_train_loss = train_loss / train_loss_count
                losses_str = f"Train Loss: {avg_train_loss:.3f}\n"

                for name, loss in results.items():
                    if name[-4:] == 'loss':
                        loss_count = int(results[name+'_count'])
                        if loss_count > 0:
                            avg_loss = loss/loss_count
                            losses_str += f"{name} ({loss_count}): {avg_loss:.3f} "
                            wandb.log({f'Train Loss/{name}': avg_loss}, step=epoch)

                losses_str += '\n'
                print(losses_str)

            dist.barrier()

            # Validation
            valid_results, valid_uid2ans = self.evaluate_epoch(epoch=epoch)

            valid_results = reduce_dict(valid_results, average=False)
            if self.verbose:
                valid_loss = valid_results['total_loss']
                valid_loss_count = valid_results['total_loss_count']

                avg_valid_loss = valid_loss / valid_loss_count
                losses_str = f"Valid Loss: {avg_valid_loss:.3f}\n"

                for name, loss in valid_results.items():
                    if name[-4:] == 'loss':
                        loss_count = int(valid_results[name+'_count'])
                        if loss_count > 0:
                            avg_loss = loss / loss_count
                            losses_str += f"{name} ({loss_count}): {avg_loss:.3f} "
                            wandb.log({f'Valid Loss/{name}': avg_loss}, step=epoch)

                losses_str += '\n'
                print(losses_str)

            # QA accuracy on the validation set, reduced across ranks.
            if 'qa' in self.args.losses:
                dset2score, dset2cnt, score, cnt = self.val_loader.dataset.evaluator.evaluate(valid_uid2ans)

                # Placeholder scores keep reduce_dict collective calls in
                # step across ranks when a rank has no QA examples.
                if len(dset2score) == 0:
                    dset2score = {'vqa': 0, 'gqa': 0, 'visual7w': 0}
                    dset2cnt = {'vqa': 1, 'gqa': 1, 'visual7w': 1}
                    cnt = 3
                    score = 0

                dset2score = reduce_dict(dset2score, average=False)
                dset2cnt = reduce_dict(dset2cnt, average=False)
                score_cnt_dict = reduce_dict({'score': score, 'cnt': cnt}, average=False)

                if self.args.gpu == 0:
                    score = score_cnt_dict['score']
                    cnt = score_cnt_dict['cnt']
                    accu = score / cnt
                    dset2accu = {}
                    for dset in dset2cnt:
                        dset2accu[dset] = dset2score[dset] / dset2cnt[dset]
                    accu_str = "Overall QA Acc %0.4f" % (accu)
                    wandb.log({f'Valid QA Acc/Overall': accu}, step=epoch)
                    sorted_keys = sorted(dset2accu.keys())
                    for key in sorted_keys:
                        accu_str += ", %s Acc %0.4f" % (key, dset2accu[key])
                        wandb.log({f'Valid QA Acc/{key}': dset2accu[key]}, step=epoch)
                    print(accu_str)
                    accu_str += '\n\n'

            dist.barrier()

            if self.verbose:
                # Save a checkpoint every epoch; best_eval_loss is tracked
                # but BEST_EVAL_LOSS saving is deliberately disabled.
                if avg_valid_loss < best_eval_loss:
                    best_eval_loss = avg_valid_loss
                #     self.save("BEST_EVAL_LOSS")
                self.save("Epoch%02d" % (epoch + 1))

            dist.barrier()

        if self.verbose:
            wandb.log({'finished': True})
Ejemplo n.º 17
0
    def evaluate_epoch(self, epoch):
        """Run one validation pass over ``self.val_loader``.

        Each batch is assigned a task by cycling through
        ``self.args.MASK_MODALITY`` (step index modulo task count), the
        corresponding losses from ``self.forward`` are accumulated into
        ``epoch_results``, and QA predictions (if ``args.task_qa``) are
        collected per example uid.

        Args:
            epoch: epoch index, used only for the progress-bar description.

        Returns:
            tuple: ``(epoch_results, uid2ans)`` where ``epoch_results`` maps
            each loss name to its summed value plus a ``*_count`` of how many
            batches contributed, and ``uid2ans`` maps uid -> answer string
            (``None`` when ``args.task_qa`` is disabled).
        """
        LOSSES_NAME = self.args.LOSSES_NAME
        # Display loss name -> batch-level task that produces it.
        task_dict = {
            'Mask_LM': 'word_mask',
            'Matched': 'matched',
            'Mask_Obj': 'vis_mask',
            'Mask_Attr': 'vis_mask',
            'Mask_Feat': 'vis_mask',
            'QA': 'qa'
        }

        # Summed loss values for the whole validation epoch.
        epoch_results = {
            'lm_loss': 0,
            'vis_loss': 0,
            'matched_loss': 0,
            'qa_loss': 0,
            'obj_loss': 0,
            'feat_loss': 0,
            'attr_loss': 0,
        }
        # Add a parallel '<name>_count' entry for every loss key
        # (list() snapshot because we mutate the dict while iterating).
        for k in list(epoch_results.keys()):
            if k[-4:] == 'loss':
                epoch_results[f'{k}_count'] = 0

        uid2ans = {}

        self.model.eval()
        with torch.no_grad():
            if self.verbose:
                loss_meter = LossMeter()
                # One meter and one batch counter per loss, aligned with LOSSES_NAME.
                loss_meters = [LossMeter() for _ in range(len(LOSSES_NAME))]

                loss_counts = [0 for _ in range(len(LOSSES_NAME))]

                pbar = tqdm(total=len(self.val_loader), ncols=180)

            for step_i, batch in enumerate(self.val_loader):
                # task = random.choice(self.args.MASK_MODALITY)
                # Deterministic round-robin over masking modalities.
                task_i = step_i % len(self.args.MASK_MODALITY)
                task = self.args.MASK_MODALITY[task_i]
                if self.args.vis_mask_COCO_only or self.args.vis_mask_COCOVG_only:
                    if task == 'vis_mask':
                        batch['word_id'] = batch['COCO_word_id']
                    # NOTE(review): this cluster_id swap is applied for every
                    # task, not only 'vis_mask' — confirm that is intended.
                    if self.args.clustering:
                        batch['cluster_id'] = batch['COCO_cluster_id']

                results = self.forward(batch, task)

                # Track how many batches contributed to each loss so averages
                # can be formed later; 'vis_mask' feeds several sub-losses.
                if task == 'vis_mask':
                    epoch_results['vis_loss_count'] += 1
                    if 'Mask_Obj' in LOSSES_NAME:
                        epoch_results['obj_loss_count'] += 1
                    if 'Mask_Feat' in LOSSES_NAME:
                        epoch_results['feat_loss_count'] += 1
                    if 'Mask_Attr' in LOSSES_NAME:
                        epoch_results['attr_loss_count'] += 1
                elif task == 'word_mask':
                    epoch_results['lm_loss_count'] += 1
                elif task == 'matched':
                    epoch_results['matched_loss_count'] += 1
                elif task == 'qa':
                    epoch_results['qa_loss_count'] += 1

                if self.args.task_qa:
                    # Convert predicted answer ids back to answer strings
                    # using the training set's answer table.
                    qa_pred = results['qa_pred']
                    for uid, ans_id in zip(batch['uid'],
                                           qa_pred.cpu().numpy()):
                        ans = self.train_loader.dataset.answer_table.id2ans(
                            ans_id)
                        uid2ans[uid] = ans

                # Accumulate any returned loss tensors we track.
                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                if self.verbose:
                    desc_str = f'Valid Epoch {epoch} | '

                    # if self.args.task_qa:
                    #     loss_meter.update(results['qa_loss'].item())
                    # QA loss is computed on every batch, so its meter
                    # (last slot) is updated unconditionally.
                    if self.args.task_qa:
                        loss_meter = loss_meters[-1]
                        loss_meter.update(results['qa_loss'].item())
                        loss_counts[-1] += 1

                    # Update only the meters whose task ran this step.
                    for i, (loss_name, loss_meter) in enumerate(
                            zip(LOSSES_NAME, loss_meters)):
                        if task_dict[loss_name] == task:
                            if task == 'vis_mask':
                                if loss_name == 'Mask_Obj':
                                    loss_meter.update(
                                        results['obj_loss'].item())
                                elif loss_name == 'Mask_Attr':
                                    loss_meter.update(
                                        results['attr_loss'].item())
                                elif loss_name == 'Mask_Feat':
                                    loss_meter.update(
                                        results['feat_loss'].item())
                            elif task == 'word_mask':
                                loss_meter.update(results['lm_loss'].item())
                            elif task == 'matched':
                                loss_meter.update(
                                    results['matched_loss'].item())
                            # elif task == 'qa':
                            #     loss_meter.update(results['qa_loss'].item())
                            loss_counts[i] += 1
                        if len(loss_meter) > 0:
                            loss_count = loss_counts[i]
                            if loss_name in [
                                    'Mask_LM', 'Matched', 'Mask_Obj',
                                    'Mask_Attr', 'Mask_Feat', 'QA'
                            ]:
                                desc_str += f' {loss_name} ({loss_count}) {loss_meter.val:.2f}'
                            else:
                                desc_str += f' {loss_name} {loss_meter.val:.2f}'

                    pbar.set_description(desc_str)
                    pbar.update(1)
                # Keep distributed ranks in lockstep after every batch.
                dist.barrier()

            if self.verbose:
                pbar.close()
            dist.barrier()

            if not self.args.task_qa:
                uid2ans = None

            return epoch_results, uid2ans
Ejemplo n.º 18
0
def train(cont=False):
    """Train the configured segmentation network on Cityscapes.

    Args:
        cont (bool): when True, resume model/optimizer/epoch/best-loss from
            ``best_ckpt_src``; otherwise start fresh (model is wrapped in
            ``nn.DataParallel`` and moved to ``device``).

    Uses module-level configuration (``net``, ``data_dir``, ``batch_size``,
    ``lrate``, ``epoch_num``/``iteration_num``, ``use_dice_loss``, ...).
    Side effects: writes TensorBoard scalars, logs progress, and saves the
    best checkpoint to ``ckpt_src``. Stops early after
    ``early_stop_tolerance`` epochs without validation improvement.
    """

    # for tensorboard tracking
    logger = get_logger()
    logger.info("(1) Initiating Training ... ")
    logger.info("Training on device: {}".format(device))
    writer = SummaryWriter()

    # init model — SETR variants also return auxiliary decoder heads
    aux_layers = None
    if net == "SETR-PUP":
        aux_layers, model = get_SETR_PUP()
    elif net == "SETR-MLA":
        aux_layers, model = get_SETR_MLA()
    elif net == "TransUNet-Base":
        model = get_TransUNet_base()
    elif net == "TransUNet-Large":
        model = get_TransUNet_large()
    elif net == "UNet":
        model = UNet(CLASS_NUM)

    # prepare dataset
    cluster_model = get_clustering_model(logger)
    train_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="train",
                                     cluster_model=cluster_model)
    valid_dataset = CityscapeDataset(img_dir=data_dir,
                                     img_dim=IMG_DIM,
                                     mode="val",
                                     cluster_model=cluster_model)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False)

    logger.info("(2) Dataset Initiated. ")

    # optimizer — derive epoch count from the iteration budget when
    # epoch_num is not positive
    epochs = epoch_num if epoch_num > 0 else iteration_num // len(
        train_loader) + 1
    optim = SGD(model.parameters(),
                lr=lrate,
                momentum=momentum,
                weight_decay=wdecay)
    # optim = Adam(model.parameters(), lr=lrate)
    scheduler = lr_scheduler.MultiStepLR(
        optim, milestones=[int(epochs * fine_tune_ratio)], gamma=0.1)

    cur_epoch = 0
    best_loss = float('inf')
    epochs_since_improvement = 0

    # for continue training
    if cont:
        model, optim, cur_epoch, best_loss = load_ckpt_continue_training(
            best_ckpt_src, model, optim, logger)
        logger.info("Current best loss: {0}".format(best_loss))
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # fast-forward the scheduler to the resumed epoch
            for i in range(cur_epoch):
                scheduler.step()
    else:
        model = nn.DataParallel(model)
        model = model.to(device)

    logger.info("(3) Model Initiated ... ")
    logger.info("Training model: {}".format(net) + ". Training Started.")

    # loss
    ce_loss = CrossEntropyLoss()
    if use_dice_loss:
        dice_loss = DiceLoss(CLASS_NUM)

    # loop over epochs
    iter_count = 0
    epoch_bar = tqdm.tqdm(total=epochs,
                          desc="Epoch",
                          position=cur_epoch,
                          leave=True)
    logger.info("Total epochs: {0}. Starting from epoch {1}.".format(
        epochs, cur_epoch + 1))

    for e in range(epochs - cur_epoch):
        epoch = e + cur_epoch

        # Training.
        model.train()
        trainLossMeter = LossMeter()
        train_batch_bar = tqdm.tqdm(total=len(train_loader),
                                    desc="TrainBatch",
                                    position=0,
                                    leave=True)

        for batch_num, (orig_img, mask_img) in enumerate(train_loader):
            orig_img, mask_img = orig_img.float().to(
                device), mask_img.float().to(device)

            if net == "TransUNet-Base" or net == "TransUNet-Large":
                pred = model(orig_img)
            elif net == "SETR-PUP" or net == "SETR-MLA":
                if aux_layers is not None:
                    pred, _ = model(orig_img)
                else:
                    pred = model(orig_img)
            elif net == "UNet":
                pred = model(orig_img)

            loss_ce = ce_loss(pred, mask_img[:].long())
            if use_dice_loss:
                loss_dice = dice_loss(pred, mask_img, softmax=True)
                loss = 0.5 * (loss_ce + loss_dice)
            else:
                loss = loss_ce

            # Backward Propagation, Update weight and metrics
            optim.zero_grad()
            loss.backward()
            optim.step()

            # update learning rate: polynomial decay over the iteration budget
            for param_group in optim.param_groups:
                orig_lr = param_group['lr']
                param_group['lr'] = orig_lr * (1.0 -
                                               iter_count / iteration_num)**0.9
            iter_count += 1

            # Update loss
            trainLossMeter.update(loss.item())

            # print status
            if (batch_num + 1) % print_freq == 0:
                status = 'Epoch: [{0}][{1}/{2}]\t' \
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(train_loader), loss=trainLossMeter)
                logger.info(status)

            # log loss to tensorboard
            if (batch_num + 1) % tensorboard_freq == 0:
                writer.add_scalar(
                    'Train_Loss_{0}'.format(tensorboard_freq),
                    trainLossMeter.avg,
                    epoch * (len(train_loader) / tensorboard_freq) +
                    (batch_num + 1) / tensorboard_freq)
            train_batch_bar.update(1)

        writer.add_scalar('Train_Loss_epoch', trainLossMeter.avg, epoch)

        # Validation.
        model.eval()
        validLossMeter = LossMeter()
        valid_batch_bar = tqdm.tqdm(total=len(valid_loader),
                                    desc="ValidBatch",
                                    position=0,
                                    leave=True)
        with torch.no_grad():
            for batch_num, (orig_img, mask_img) in enumerate(valid_loader):
                orig_img, mask_img = orig_img.float().to(
                    device), mask_img.float().to(device)

                if net == "TransUNet-Base" or net == "TransUNet-Large":
                    pred = model(orig_img)
                elif net == "SETR-PUP" or net == "SETR-MLA":
                    if aux_layers is not None:
                        pred, _ = model(orig_img)
                    else:
                        pred = model(orig_img)
                elif net == "UNet":
                    pred = model(orig_img)

                loss_ce = ce_loss(pred, mask_img[:].long())
                if use_dice_loss:
                    loss_dice = dice_loss(pred, mask_img, softmax=True)
                    loss = 0.5 * (loss_ce + loss_dice)
                else:
                    loss = loss_ce

                # Update loss
                validLossMeter.update(loss.item())

                # print status (fixed: these blocks were dedented out of the
                # batch loop, so they ran once per epoch instead of per batch
                # and the progress bar advanced only a single step)
                if (batch_num + 1) % print_freq == 0:
                    status = 'Validation: [{0}][{1}/{2}]\t' \
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch+1, batch_num+1, len(valid_loader), loss=validLossMeter)
                    logger.info(status)

                # log loss to tensorboard
                if (batch_num + 1) % tensorboard_freq == 0:
                    writer.add_scalar(
                        'Valid_Loss_{0}'.format(tensorboard_freq),
                        validLossMeter.avg,
                        epoch * (len(valid_loader) / tensorboard_freq) +
                        (batch_num + 1) / tensorboard_freq)
                valid_batch_bar.update(1)

        valid_loss = validLossMeter.avg
        writer.add_scalar('Valid_Loss_epoch', valid_loss, epoch)
        logger.info("Validation Loss of epoch [{0}/{1}]: {2}\n".format(
            epoch + 1, epochs, valid_loss))

        # update optim scheduler
        scheduler.step()

        # save checkpoint only on improvement; otherwise count toward
        # early stopping
        is_best = valid_loss < best_loss
        best_loss_tmp = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: %d\n" %
                        (epochs_since_improvement, ))
            if epochs_since_improvement == early_stop_tolerance:
                break  # early stopping.
        else:
            epochs_since_improvement = 0
            state = {
                'epoch': epoch,
                'loss': best_loss_tmp,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
            }
            torch.save(state, ckpt_src)
            logger.info("Checkpoint updated.")
            best_loss = best_loss_tmp
        epoch_bar.update(1)
    writer.close()
Ejemplo n.º 19
0
    def train(self):
        """Fine-tune the VQA model and validate after each epoch.

        Iterates ``self.args.epochs`` epochs over ``self.train_loader``,
        optimizing a BCE loss over answer logits with optional gradient
        accumulation (``args.update_freq``) and gradient clipping. On the
        verbose (main) rank it logs train/valid scores to TensorBoard,
        keeps the best-scoring checkpoint as "BEST", and saves "LAST" at
        the end. Uses ``dist.barrier()`` to keep distributed ranks in sync.
        """
        if self.verbose:
            loss_meter = LossMeter()
            quesid2ans = {}
            best_valid = 0.
            print("Valid Oracle: %0.2f" %
                  (self.oracle_score(self.val_loader) * 100))

            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(log_dir=self.args.log_dir)

            # Only scalar-ish arg values are valid TensorBoard hparams.
            hparam_dict = {}
            for k, v in self.args.__dict__.items():
                if type(v) in [int, float, str, bool, torch.Tensor]:
                    hparam_dict[k] = v
            metric_dict = {}

            self.writer.add_hparams(hparam_dict, metric_dict)

        if self.args.distributed:
            dist.barrier()

        self.optim.zero_grad()

        for epoch in range(self.args.epochs):
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=150)

            quesid2ans = {}
            for step_i, batch in enumerate(self.train_loader):
                # Gradient accumulation: step the optimizer only every
                # update_freq batches (and on the final batch).
                update = True
                if self.args.update_freq > 1:
                    if step_i == 0:
                        update = False
                    elif step_i % self.args.update_freq == 0 or step_i == len(
                            self.train_loader) - 1:
                        update = True
                    else:
                        update = False

                vis_feats = batch['vis_feats'].cuda()
                boxes = batch['boxes'].cuda()

                input_ids = batch['word_ids'].cuda()
                target = batch['targets'].cuda()

                ques_id = batch['question_ids']

                B = len(batch['word_ids'])

                results = self.model(
                    input_ids=input_ids,
                    visual_feats=vis_feats,
                    visual_pos=boxes,
                    attention_mask=input_ids > 0,
                )
                logit = results['logit']

                assert logit.size() == target.size()
                assert logit.size() == (B, self.num_answers)

                loss = self.bce_loss(logit, target)

                loss.backward()

                if update:
                    if not self.args.no_clip_grad:
                        nn.utils.clip_grad_norm_(self.model.parameters(),
                                                 self.args.clip_grad_norm)

                    self.optim.step()
                    self.lr_scheduler.step()
                    # Cheaper than zero_grad(): drop grads entirely.
                    for param in self.model.parameters():
                        param.grad = None

                # Fixed: was `self.scheduler`, which does not exist — the
                # AttributeError fallback silently reported the static
                # args.lr instead of the scheduled learning rate.
                try:
                    lr = self.lr_scheduler.get_last_lr()[0]
                except AttributeError:
                    lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | '
                    desc_str += f'Loss {loss_meter.val:4f} |'

                    score, predict = logit.max(1)
                    predict = predict.cpu().numpy()
                    target = target.cpu().numpy()

                    for qid, pred in zip(ques_id, predict):
                        pred_ans = self.train_loader.dataset.raw_dataset.label2ans[
                            pred]
                        quesid2ans[qid] = pred_ans

                    pbar.set_description(desc_str)
                    pbar.update(1)

                if self.args.distributed:
                    dist.barrier()

            if self.verbose:
                pbar.close()
                score = self.train_loader.evaluator.evaluate(quesid2ans) * 100.
                log_str = "\nEpoch %d: Train %0.2f" % (epoch, score)

                if not self.args.dry:
                    self.writer.add_scalar(f'GQA/Train/score', score, epoch)

                # Validation
                valid_score = self.evaluate(self.val_loader) * 100.
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid)

                if not self.args.dry:
                    self.writer.add_scalar(f'GQA/Valid/score', valid_score,
                                           epoch)

                print(log_str)
                self.logger.info(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")
Ejemplo n.º 20
0
Archivo: mmt.py Proyecto: j-min/VL-T5
    def train(self):
        """Train the multimodal MT model, validating by BLEU each epoch.

        Supports fp16 via native AMP or apex, gradient accumulation
        (``args.gradient_accumulation_steps``) and gradient clipping. The
        verbose rank initializes wandb, tracks the best BLEU checkpoint
        ("BEST"), saves "LAST" at the end, then reloads BEST and evaluates
        on the test loader(s), uploading prediction dumps to wandb.
        """
        if self.verbose:
            loss_meter = LossMeter()
            best_valid = 0.
            best_epoch = 0

            # wandb project name depends on backbone family and vision usage.
            if 't5' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLT5_MMT"
                else:
                    project_name = "T5_MMT"
            elif 'bart' in self.args.backbone:
                if self.args.use_vision:
                    project_name = "VLBart_MMT"
                else:
                    project_name = "Bart_MMT"

            wandb.init(project=project_name)
            wandb.run.name = self.args.run_name
            wandb.config.update(self.args)
            wandb.watch(self.model)

            # Snapshot the source directory for reproducibility.
            src_dir = Path(__file__).resolve().parent
            base_path = str(src_dir.parent)
            src_dir = str(src_dir)
            wandb.save(os.path.join(src_dir + "/*.py"), base_path=base_path)

        if self.args.distributed:
            dist.barrier()

        global_step = 0
        for epoch in range(self.args.epochs):
            if self.start_epoch is not None:
                epoch += self.start_epoch
            self.model.train()
            if self.args.distributed:
                self.train_loader.sampler.set_epoch(epoch)
            if self.verbose:
                pbar = tqdm(total=len(self.train_loader), ncols=120)

            epoch_results = {
                'loss': 0.,

            }

            for step_i, batch in enumerate(self.train_loader):


                # self.optim.zero_grad()
                if self.args.fp16 and _use_native_amp:
                    with autocast():
                        if self.args.distributed:
                            results = self.model.module.train_step(batch)
                        else:
                            results = self.model.train_step(batch)
                else:
                    if self.args.distributed:
                        results = self.model.module.train_step(batch)
                    else:
                        results = self.model.train_step(batch)

                loss = results['loss']

                # print(f'GPU{self.args.gpu} after loss')

                # Backward under the appropriate fp16 regime.
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optim) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # print(f'GPU{self.args.gpu} after backward')

                loss = loss.detach()

                # Update Parameters
                if self.args.clip_grad_norm > 0:
                    if self.args.fp16 and _use_native_amp:
                        # Must unscale before clipping under native AMP.
                        self.scaler.unscale_(self.optim)
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(
                            self.optim), self.args.clip_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.args.clip_grad_norm)

                # Gradient accumulation: only step every N batches.
                update = True
                if self.args.gradient_accumulation_steps > 1:
                    if step_i == 0:
                        update = False
                    elif step_i % self.args.gradient_accumulation_steps == 0 or step_i == len(self.train_loader) - 1:
                        update = True
                    else:
                        update = False

                if update:
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optim)
                        self.scaler.update()
                    else:
                        self.optim.step()

                    if self.lr_scheduler:
                        self.lr_scheduler.step()
                    # self.model.zero_grad()
                    for param in self.model.parameters():
                        param.grad = None
                    global_step += 1

                for k, v in results.items():
                    if k in epoch_results:
                        epoch_results[k] += v.item()

                # get_last_lr is only available on torch >= 1.4.
                if self.lr_scheduler:
                    if version.parse(torch.__version__) >= version.parse("1.4"):
                        lr = self.lr_scheduler.get_last_lr()[0]
                    else:
                        lr = self.lr_scheduler.get_lr()[0]
                else:
                    try:
                        lr = self.optim.get_lr()[0]
                    except AttributeError:
                        lr = self.args.lr

                if self.verbose:
                    loss_meter.update(loss.item())
                    desc_str = f'Epoch {epoch} | LR {lr:.6f} | Steps {global_step}'
                    desc_str += f' | Loss {loss_meter.val:4f}'
                    pbar.set_description(desc_str)
                    pbar.update(1)

                # if self.args.distributed:
                #     dist.barrier()

            if self.verbose:
                pbar.close()


                # Validation
                valid_results = self.evaluate(self.val_loader)

                valid_score = valid_results['BLEU']

                if valid_score > best_valid:
                    best_valid = valid_score
                    best_epoch = epoch
                    self.save("BEST")

                log_str = ''

                log_str += pformat(valid_results)
                log_str += "\nEpoch %d: Valid BLEU %0.4f" % (epoch, valid_score)
                log_str += "\nEpoch %d: Best BLEU %0.4f\n" % (best_epoch, best_valid)

                wandb_log_dict = {}
                wandb_log_dict['Train/Loss'] = epoch_results['loss'] / len(self.train_loader)

                for score_name, score in valid_results.items():
                    wandb_log_dict[f'Valid/{score_name}'] = score

                wandb.log(wandb_log_dict, step=epoch)

                print(log_str)

            if self.args.distributed:
                dist.barrier()

        if self.verbose:
            self.save("LAST")

            # Test Set: reload the best checkpoint before final evaluation.
            best_path = os.path.join(self.args.output, 'BEST')
            self.load(best_path)

            if isinstance(self.test_loader, list):

                for loader in self.test_loader:

                    split = loader.dataset.source
                    dump_path = os.path.join(self.args.output, f'submit_{split}_raw.txt')
                    test_results = self.evaluate(loader, dump_path=dump_path)

                    wandb_log_dict = {}
                    for score_name, score in test_results.items():
                        wandb_log_dict[f'{split}/{score_name}'] = score
                    wandb.log(wandb_log_dict, step=epoch)

                    log_str = f'{split} set results\n'
                    log_str += pformat(test_results)

                    print(log_str)

                    wandb.save(dump_path, base_path=self.args.output)
                    print('\nUploaded', dump_path)

            else:
                # Fixed: `loader` was never bound in this branch (only the
                # list branch defines it), raising NameError for a single
                # test loader.
                loader = self.test_loader
                split = loader.dataset.source
                dump_path = os.path.join(self.args.output, f'submit_{split}_raw.txt')
                test_results = self.evaluate(loader, dump_path=dump_path)

                wandb_log_dict = {}
                for score_name, score in test_results.items():
                    wandb_log_dict[f'{split}/{score_name}'] = score
                wandb.log(wandb_log_dict, step=epoch)

                log_str = f'{split} set results\n'
                log_str += pformat(test_results)

                print(log_str)

                wandb.save(dump_path, base_path=self.args.output)
                print('\nUploaded', dump_path)

            wandb.log({'finished': True})
Ejemplo n.º 21
0
class Finetuner:
    """Fine-tune a pretrained WaveNet autoencoder on new datasets.

    Loads encoder/decoder weights from an existing checkpoint, freezes the
    encoder entirely and all decoder layers except the last
    ``args.decoder_update`` ones, then trains with cross-entropy
    reconstruction loss. One worker (rank) handles one dataset in
    distributed mode.
    """

    def __init__(self, args):
        self.args = args
        self.args.n_datasets = len(args.data)
        self.modelPath = Path('checkpoints') / args.expName

        self.logger = create_output_dir(args, self.modelPath)
        self.data = [DatasetSet(d, args.seq_len, args) for d in args.data]

        # Per-dataset training/eval reconstruction meters plus totals.
        self.losses_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.loss_total = LossMeter('total')

        self.evals_recon = [
            LossMeter(f'recon {i}') for i in range(self.args.n_datasets)
        ]
        self.eval_total = LossMeter('eval total')

        self.start_epoch = 0

        #torch.manual_seed(args.seed)
        #torch.cuda.manual_seed(args.seed)

        # Get the pretrained model checkpoint whose id matches args.decoder.
        checkpoint = args.checkpoint.parent.glob(args.checkpoint.name +
                                                 '_*.pth')
        checkpoint = [c for c in checkpoint
                      if extract_id(c) in args.decoder][0]

        model_args = torch.load(args.checkpoint.parent / 'args.pth')[0]

        # (Fixed: encoder/decoder were each constructed twice; the first
        # pair was dead and immediately overwritten.)
        self.encoder = Encoder(model_args)
        self.encoder.load_state_dict(torch.load(checkpoint)['encoder_state'])

        # Encoder is fully frozen.
        for param in self.encoder.parameters():
            param.requires_grad = False
            #self.logger.debug(f'encoder at start: {param}')

        self.decoder = WaveNet(model_args)
        self.decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])

        # Freeze all decoder layers except the last `decoder_update` ones.
        for param in self.decoder.layers[:-args.decoder_update].parameters():
            param.requires_grad = False
            #self.logger.debug(f'decoder at start: {param}')

        self.encoder = torch.nn.DataParallel(self.encoder).cuda()
        self.decoder = torch.nn.DataParallel(self.decoder).cuda()
        self.model_optimizer = optim.Adam(chain(self.encoder.parameters(),
                                                self.decoder.parameters()),
                                          lr=args.lr)

        self.lr_manager = torch.optim.lr_scheduler.ExponentialLR(
            self.model_optimizer, args.lr_decay)
        self.lr_manager.step()

    def train_batch(self, x, x_aug, dset_num):
        'train batch without considering the discriminator'
        x = x.float()
        x_aug = x_aug.float()
        # Encode the augmented signal; decode conditioned on the original.
        z = self.encoder(x_aug)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        loss = recon_loss.mean()

        self.model_optimizer.zero_grad()
        loss.backward()
        self.model_optimizer.step()
        self.loss_total.add(loss.data.item())

        return loss.data.item()

    def train_epoch(self, epoch):
        """Run one training epoch of ``args.epoch_len`` batches."""
        for meter in self.losses_recon:
            meter.reset()
        self.loss_total.reset()

        self.decoder.train()

        n_batches = self.args.epoch_len

        with tqdm(total=n_batches,
                  desc='Train epoch %d' % epoch) as train_enum:
            for batch_num in range(n_batches):
                # --short: truncated epochs for debugging
                if self.args.short and batch_num == 3:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    # dset_num = (batch_num + self.args.rank) % self.args.n_datasets
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].train_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.train_batch(x, x_aug, dset_num)

                train_enum.set_description(
                    f'Train (loss: {batch_loss:.2f}) epoch {epoch}')
                train_enum.update()

    def eval_batch(self, x, x_aug, dset_num):
        """Compute reconstruction loss for one validation batch."""
        x, x_aug = x.float(), x_aug.float()
        # NOTE(review): eval encodes the clean signal x, while training
        # encodes x_aug — confirm this asymmetry is intended.
        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss

    def evaluate_epoch(self, epoch):
        """Run one validation epoch (~1/10 of a training epoch)."""
        for meter in self.evals_recon:
            meter.reset()
        self.eval_total.reset()

        self.encoder.eval()
        self.decoder.eval()

        n_batches = int(np.ceil(self.args.epoch_len / 10))

        with tqdm(total=n_batches) as valid_enum, torch.no_grad():
            for batch_num in range(n_batches):
                if self.args.short and batch_num == 10:
                    break

                if self.args.distributed:
                    assert self.args.rank < self.args.n_datasets, "No. of workers must be equal to #dataset"
                    dset_num = self.args.rank
                else:
                    dset_num = batch_num % self.args.n_datasets

                x, x_aug = next(self.data[dset_num].valid_iter)

                x = wrap(x)
                x_aug = wrap(x_aug)
                batch_loss = self.eval_batch(x, x_aug, dset_num)

                valid_enum.set_description(
                    f'Test (loss: {batch_loss:.2f}) epoch {epoch}')
                valid_enum.update()

    @staticmethod
    def format_losses(meters):
        """Format each meter's epoch summary as a comma-separated string."""
        losses = [meter.summarize_epoch() for meter in meters]
        return ', '.join('{:.4f}'.format(x) for x in losses)

    def train_losses(self):
        meters = [*self.losses_recon]
        return self.format_losses(meters)

    def eval_losses(self):
        meters = [*self.evals_recon]
        return self.format_losses(meters)

    def finetune(self):
        """Main loop: train/evaluate each epoch, saving best and last models."""
        best_eval = float('inf')

        for epoch in range(self.start_epoch,
                           self.start_epoch + self.args.epochs):
            self.logger.info(
                f'Starting epoch, Rank {self.args.rank}, Dataset: {self.args.data[self.args.rank]}'
            )
            self.train_epoch(epoch)
            self.evaluate_epoch(epoch)

            self.logger.info(
                f'Epoch %s Rank {self.args.rank} - Train loss: (%s), Test loss (%s)',
                epoch, self.train_losses(), self.eval_losses())
            self.lr_manager.step()
            val_loss = self.eval_total.summarize_epoch()

            if val_loss < best_eval:
                self.save_model(f'bestmodel_{self.args.rank}.pth')
                best_eval = val_loss

            if not self.args.per_epoch:
                self.save_model(f'lastmodel_{self.args.rank}.pth')
            else:
                self.save_model(f'lastmodel_{epoch}_rank_{self.args.rank}.pth')

            if self.args.is_master:
                torch.save([self.args, epoch], '%s/args.pth' % self.modelPath)
            self.logger.debug('Ended epoch')

    def save_model(self, filename):
        """Save encoder/decoder/optimizer state to ``modelPath / filename``."""
        save_path = self.modelPath / filename

        torch.save(
            {
                'encoder_state': self.encoder.module.state_dict(),
                'decoder_state': self.decoder.module.state_dict(),
                'model_optimizer_state': self.model_optimizer.state_dict(),
                'dataset': self.args.rank,
            }, save_path)

        self.logger.debug(f'Saved model to {save_path}')