def train_loop_fn(loader):
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                xm.add_step_closure(_train_update,
                                    args=(device, x, loss, tracker))
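Each of these loops hands a _train_update callable to xm.add_step_closure; the closure runs on the host once the step's tensors have been materialized, so calling loss.item() inside it does not force an extra device-to-host sync in the middle of graph tracing. The helper itself is not shown in the snippets. A minimal sketch matching this first loop's four-argument call, assuming the test_utils.print_training_update helper from the PyTorch/XLA test utilities (later snippets pass extra epoch/writer arguments):

def _train_update(device, step, loss, tracker):
    # Executed after the step's graph has run: loss is already on the host.
    test_utils.print_training_update(device, step, loss.item(), tracker.rate(),
                                     tracker.global_rate())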
Example 2
 def test_loop_fn(loader, epoch):
     total_samples, correct = 0, 0
     model.eval()
     for step, (data, target) in enumerate(loader):
         output = model(data)
         pred = output.max(1, keepdim=True)[1]
         correct += pred.eq(target.view_as(pred)).sum()
         total_samples += data.size()[0]
         if step % FLAGS.log_steps == 0:
             xm.add_step_closure(test_utils.print_test_update,
                                 args=(device, None, epoch, step))
     accuracy = 100.0 * correct.item() / total_samples
     accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
     return accuracy
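The xm.mesh_reduce call at the end averages the per-replica accuracy across all cores. A sketch of how a loop like this is typically driven, assuming an existing test_loader DataLoader and epoch counter; pl.MpDeviceLoader moves each batch to the XLA device and issues the mark_step that flushes the queued step closures:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
test_device_loader = pl.MpDeviceLoader(test_loader, device)
accuracy = test_loop_fn(test_device_loader, epoch)
xm.master_print('Epoch {} test accuracy: {:.2f}%'.format(epoch, accuracy))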
Example 3
 def train_loop_fn(loader):
   tracker = xm.RateTracker()
   model.train()
   for step, (data, target) in enumerate(loader):
     optimizer.zero_grad()
     output = model(data)
     loss = loss_fn(output, target)
     loss.backward()
     xm.optimizer_step(optimizer)
     tracker.add(flags.batch_size)
     if step % flags.log_steps == 0:
       xm.add_step_closure(
           _train_update,
           args=(device, step, loss, tracker, writer),
           run_async=FLAGS.async_closures)
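Here the closure is registered with run_async=FLAGS.async_closures: when the flag is set, the closure is executed asynchronously on a helper thread instead of blocking the step, trading strict ordering for throughput when logging is slow. The same call with the flag written out literally (all other names as in the snippet above):

xm.add_step_closure(
    _train_update,
    args=(device, step, loss, tracker, writer),
    run_async=True)  # log from a background thread; do not stall the training loop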
Example 4
    def test_synchronous_exception(self):
        flag = Event()
        assert not flag.is_set()

        try:

            def closure():
                flag.set()
                raise RuntimeError("Simulating exception in closure")

            xm.add_step_closure(closure)
            xm.mark_step()

            assert False  # Should not reach here
        except RuntimeError as e:
            assert flag.is_set(), "Should have caught exception from closure"
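The test works because step closures run synchronously inside xm.mark_step(), so the RuntimeError raised in the closure propagates to the caller. A standalone sketch of the same check, assuming an XLA device is available; Event comes from the standard-library threading module:

import threading

import torch_xla.core.xla_model as xm


def check_closure_exception():
    flag = threading.Event()

    def closure():
        flag.set()
        raise RuntimeError("Simulating exception in closure")

    xm.add_step_closure(closure)
    try:
        xm.mark_step()  # closures run here; the exception escapes to the caller
        assert False, "mark_step() should have re-raised the closure's error"
    except RuntimeError:
        assert flag.is_set(), "the closure ran before raising"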
Example 5
 def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         output = model(data)
         loss = loss_fn(output, target)
         loss.backward()
         optimizer.step()  # do not reduce gradients on sharded params
         tracker.add(FLAGS.batch_size)
         if lr_scheduler:
             lr_scheduler.step()
         if step % FLAGS.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, step, loss, tracker, epoch,
                                       writer))
Example 6
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         with autocast():
             output = model(data)
             loss = loss_fn(output, target)
         scaler.scale(loss).backward()
         gradients = xm._fetch_gradients(optimizer)
         xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
         scaler.step(optimizer)
         scaler.update()
         tracker.add(flags.batch_size)
         if step % flags.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, step, loss, tracker, writer))
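Because scaler.step(optimizer) is used here instead of xm.optimizer_step(optimizer), the cross-replica gradient reduction has to be done by hand before stepping. A small helper capturing that pattern as a sketch (the function name is hypothetical; the xm calls are the ones from the snippet):

def reduce_gradients_and_step(optimizer, scaler):
    # Average the gradients across replicas, then let GradScaler unscale,
    # check for inf/nan, and apply the optimizer step.
    gradients = xm._fetch_gradients(optimizer)
    xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
    scaler.step(optimizer)
    scaler.update()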
Example 7
    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            with xp.StepTrace('train_mnist', step_num=step):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, target)
                    loss.backward()
                xm.optimizer_step(optimizer)

                tracker.add(flags.batch_size)
                if step % flags.log_steps == 0:
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              writer))
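This loop wraps each step in xp.StepTrace / xp.Trace annotations so the work shows up as named scopes in a captured profile; xp here is presumably torch_xla.debug.profiler. A minimal sketch of the setup such annotations assume (the port number is an arbitrary choice):

import torch_xla.debug.profiler as xp

# Start the profiler server once per process; traces annotated with
# xp.StepTrace / xp.Trace can then be captured, e.g. from TensorBoard.
server = xp.start_server(9012)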
Example 8
    def train_loop_fn(device, trainer, loader, last_batch_index):
        """
        This is the main training loop. It trains for 1 epoch.
        """

        def print_training_update(trainer, progress, args, i):
            stats = get_training_stats(trainer, args=args)
            stats['now'] = now()
            progress.log(stats, tag='train', step=trainer.get_num_updates())
            progress.print_mid_epoch(i+1, force=True)

        stats, log_output, skip_stat_keys = None, None, {'clip'}
        max_update = args.max_update or math.inf
        for i, samples in enumerate(loader, start=epoch_itr.iterations_in_epoch):
            if i == last_batch_index:
                # last batches are incomplete
                break
            log_output = trainer.train_step(samples)
            reset_perf_training_meters(trainer, i, ignore_index=10)
            if (not (i % args.log_steps)) or (i == last_batch_index-1):
                step_args = trainer, progress, args, i
                xm.add_step_closure(print_training_update, args=step_args)
            num_updates = trainer.get_num_updates()
            if (
                not args.disable_validation
                and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0
            ):
                vloss = validate_subset(
                    args, device, trainer, task, epoch_itr, valid_subsets[0]
                )
                checkpoint_utils.save_checkpoint(
                    args, trainer, epoch_itr, vloss.item(),
                    epoch=epoch, end_of_epoch=False,
                )
            if num_updates >= max_update:
                break
Example 9
def validate(args, e, mdl, opt, ls_f, vd_dl, tp_ix, tp_rix, ls_mtr, tb_sw):
    vd_ls = Integer(0)
    mtr = Metric()

    vd = pl.ParallelLoader(vd_dl, [
        args.dvc,
    ]).per_device_loader(args.dvc) if args.tpu else vd_dl

    mdl.eval()
    with torch.no_grad():
        for p, n, w, x, y in vd:
            p = p.view(-1, p.shape[-1]).to(args.aux_dvc)
            n = n.view(-1, n.shape[-1]).to(args.aux_dvc)
            x = x.view(-1, x.shape[-1]).to(args.aux_dvc)

            ls = _loss(args, p, n, w, mdl, ls_f)
            if args.tpu:
                xm.add_step_closure(_validate, args=(vd_ls, ls))
            else:
                _validate(vd_ls, ls)

            evaluate(args, x, y, mdl, tp_ix, tp_rix, mtr)

    vd_ls.val /= len(vd_dl)
    vd_ls_avg = vd_ls.val if args.tpu else _allreduce(
        vd_ls.val, 'validate.vd_ls_avg', hvd.Average)
    if not args.tpu:
        mtr.allreduce()

    if is_master(args):
        print(f'Epoch {e}/{args.epochs} validation loss: {vd_ls_avg}')
        print(mtr)
        tb_sw.add_scalar(f'loss/validation', vd_ls_avg, e)
        tb_sw.add_scalars('validation', dict(mtr))

        is_bst, bst_ls = ls_mtr.update(vd_ls_avg)
        _checkpoint(args, e, mdl, opt, bst_ls, is_bst)
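In this example the loss accumulation itself is deferred through xm.add_step_closure, so ls is only read on the host after the step's graph has executed. The Integer holder and the _validate helper are not shown; a sketch of what they might look like (only the names come from the snippet, the bodies are assumptions):

class Integer:
    """Mutable scalar holder that a step closure can update in place."""

    def __init__(self, val=0):
        self.val = val


def _validate(vd_ls, ls):
    # Runs as a step closure: ls has been materialized, so .item() is cheap.
    vd_ls.val += ls.item()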
Example 10
def train(rank, args):
    print('enter train @ %s'%(rank), flush=True)
    args.rank = rank
    args.split = ''
    torch.manual_seed(42)
    save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size() if not args.vocab_size else args.vocab_size
    
    train_dataset = get_dataset(args)
    
    batched_already = hasattr(train_dataset, '__getbatch__')

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)
    
    train_sampler = None
    if args.gpus:
        dist.init_process_group(
            'nccl', 
            rank=rank, 
            world_size=args.world_size
        )
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=args.shuffle)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=args.shuffle)


    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not batched_already else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)
        

    eval_loaders = []
    if args.eval_dir:
        for split in args.splits.split(','):
            split = split.strip()
            args.split = split
            eval_dataset = get_eval_dataset(args)

            eval_sampler = None
            if args.gpus:
                if args.gpus > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=args.gpus,
                        rank=rank,
                        shuffle=False)

            else:
                rank = xm.get_ordinal()
                if xm.xrt_world_size() > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=xm.xrt_world_size(),
                        rank=rank,
                        shuffle=False)

            eval_loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=args.batch_size if not batched_already else None,
                sampler=eval_sampler,
                pin_memory=True,
                shuffle=False,
                num_workers=args.num_workers)
            eval_loaders.append(eval_loader)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)

        model.cuda(rank)

        device = torch.device('cuda:'+str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay,
                                 special_layer_wise_lr=args.special_layer_wise_lr,
                                 log=rank == 0),
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay
        )

        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        
        model = get_model(args, tokenizer)


        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  So we first find the shared parameters first
        ##
        ##########################

        shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)}

        model.to(device)
        
        do_share_parameters_again(model, shared_parameters, log = rank == 0)



        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),
            lr=args.lr,
            weight_decay=args.weight_decay
        )


        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
                
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
        
        
        
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
                
        if args.gpus:
            states = {"module.%s"%k : v for k, v in states.items()}
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            if rank == 0:
                traceback.print_exc()
            model.load_state_dict(states, strict=False)
            
        
    if rank == 0:
        if not os.path.exists(os.path.dirname(save_fn)):
            try:
                os.makedirs(os.path.dirname(save_fn))
            except OSError as exc: # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise
        if args.gpus:
            torch.save(model.state_dict(), save_fn )
        else:
            xm.save(model.state_dict(), save_fn )
        
    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################
    
    if not batched_already:
        args.total_num_updates = args.total_num_updates // args.batch_size
        args.warmup_updates = args.warmup_updates // args.batch_size

    args.total_num_updates = args.total_num_updates // args.world_size
    args.warmup_updates = args.warmup_updates // args.world_size

    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=args.warmup_updates, 
        num_training_steps=args.total_num_updates, 
    )

    step_i = 0

    err = None
    tb = None
    #tb = SummaryWriter()
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates, file=sys.stdout)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
                
            n_samples = len(batches)
                
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True:  # this loop only replays the batch after an apex gradient overflow
                    optimizer.zero_grad()

                    total_loss, log = get_loss(
                        model, 
                        sample, 
                        args=args, 
                        device=device, 
                        gpus=args.gpus, 
                        report=report_step
                    )

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        optimizer.step() # If an overflow was detected, "optimizer.step" is the patched call, which does 
                                         # nothing but restore optimizer.step to default_optimizer_step.
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.", flush=True)
                        
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break



                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    # tb.add_scalar("Loss", total_loss, step_i)

                    for k, v in log.items():
                        try:
                            dist.all_reduce(v, op=dist.reduce_op.SUM)
                            log[k] = float(v)
                        except Exception as e:
                            print(v, e)
                            pass
                        
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter, tb, step_i))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter, tb, step_i))

                    if args.report_metrics:
                        xm.master_print(met.metrics_report())

                
                if rank == 0:
                    pbar.update(1)

        if rank == 0:
            pbar.close()
        if eval_loaders:
            model.half()
            model.eval()
            model.cuda()
            for k, v in model.named_parameters():
                v.requires_grad = False

                
            for split, eval_loader in zip(args.splits.split(','), eval_loaders):
                batches = eval_loader
                if rank == 0:
                    eval_length = len(batches)
                    if not batched_already:
                        eval_length = eval_length // args.batch_size

                    eval_length = eval_length // args.world_size

                    pbar = tqdm(total=eval_length, file=sys.stdout)
                
                if not args.gpus:
                    batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
                with torch.no_grad():
                    record = OrderedDict()

                    for sample in batches:
                        evaluate(
                            model, 
                            sample, 
                            args=args, 
                            device=device, 
                            record=record,
                            gpus=args.gpus, 
                            report=False
                        )
                        if rank == 0:
                            pbar.update(1)

                    for k, v in record.items():
                        try:
                            def handle_reduce(v):
                                if len(v.shape) == 0:
                                    dist.all_reduce(v, op=dist.reduce_op.SUM)
                                else:
                                    L = [torch.ones_like(v) for _ in range(dist.get_world_size())]
                                    dist.all_gather(L, v)
                                    v = torch.cat(L, dim=0)
                                return v
                            if isinstance(v, list):
                                v = [handle_reduce(e) for e in v]
                            else:
                                v = handle_reduce(v)
                            record[k] = float(v)
                        except Exception as e:
                            pass

                    post_evaluate(record, args=args)

                import json

                if rank == 0:
                    print('',flush=True)
                    print('Test result for %s'%split, flush=True)
                    print(json.dumps(record, indent=2),flush=True)
                    print('',flush=True)


    except Exception as _err:
        err = _err
    finally:
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0:
            print("Saving to %s" % save_fn)
            if args.gpus:
                torch.save(model.state_dict(), save_fn)
            else:
                xm.save(model.state_dict(), save_fn)
            print("Saved to %s" % save_fn)
        if err:
            # re-raise on every rank so failures are not silently swallowed
            raise err
Example 11
def train(rank, args):
    print('enter train @ %s' % (rank), flush=True)
    args.rank = rank
    torch.manual_seed(42)

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size()

    train_dataset = get_dataset(args)

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)

    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=False)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size
        if not hasattr(train_dataset, '__getbatch__') else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)

    eval_loader = None
    if args.eval_dir:

        eval_dataset = get_eval_dataset(args)

        eval_sampler = None
        if args.gpus:
            if args.gpus > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=args.gpus,
                    rank=rank,
                    shuffle=False)

        else:
            rank = xm.get_ordinal()
            if xm.xrt_world_size() > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=rank,
                    shuffle=False)

        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size
            if not hasattr(train_dataset, '__getbatch__') else None,
            sampler=eval_sampler,
            pin_memory=True,
            shuffle=False,
            num_workers=args.num_workers)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args)

        model.cuda(rank)

        device = torch.device('cuda:' + str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)

        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################

        model = get_model(args)

        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  So we first find the shared parameters first
        ##
        ##########################

        shared_parameters = {
            e[0]: e[1:]
            for e in _catalog_shared_params(model)
        }

        model.to(device)

        do_share_parameters_again(model, shared_parameters, log=rank == 0)

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay)

        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)

        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            traceback.print_exc()
            model.load_state_dict(states, strict=False)

    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )

    step_i = 0

    err = None
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader,
                                            [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True:  # this loop only replays the batch after an apex gradient overflow
                    optimizer.zero_grad()

                    total_loss, log = get_loss(model,
                                               sample,
                                               args=args,
                                               device=device,
                                               gpu=args.gpus,
                                               report=report_step)

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss,
                                            optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        # If an overflow was detected, "optimizer.step" is the patched call,
                        # which does nothing but restore optimizer.step to default_optimizer_step.
                        optimizer.step()
                        if rank == 0:
                            print(
                                "Overflowed, reducing loss scale and replaying batch.",
                                flush=True)

                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break

                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(
                                log, log_formatter))
                    else:
                        xm.add_step_closure(_train_update,
                                            args=(log, log_formatter))

                    if args.report_metrics:
                        xm.master_print(met.metrics_report())

                if rank == 0:
                    pbar.update(1)

        if eval_loader is not None:
            model.eval()
            if not args.gpus:
                batches = pl.ParallelLoader(eval_loader,
                                            [device]).per_device_loader(device)
            with torch.no_grad():
                record = OrderedDict()

                for sample in batches:
                    evaluate(model,
                             sample,
                             args=args,
                             device=device,
                             record=record,
                             gpu=args.gpus,
                             report=report_step)

                post_evaluate(record, args=args)

            import json
            print('', flush=True)
            print(json.dumps(record), flush=True)
            print('', flush=True)

    except Exception as _err:
        err = _err
    finally:
        save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if args.gpus:
            if rank == 0:
                torch.save(model.state_dict(), save_fn)
        else:
            xm.save(model.state_dict(), save_fn)
        if err:
            raise err
Example 12
    def train(self):
        hps = self.state.hps
        ss = self.state
        current_stats = {}
        writer_stats = {}

        # for resuming the learning rate
        sorted_lr_steps = sorted(self.learning_rates.keys())
        lr_index = util.greatest_lower_bound(sorted_lr_steps,
                                             ss.data.global_step)
        ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])

        if ss.model.bn_type != 'none':
            sorted_as_steps = sorted(self.anneal_schedule.keys())
            as_index = util.greatest_lower_bound(sorted_as_steps,
                                                 ss.data.global_step)
            ss.model.objective.update_anneal_weight(
                self.anneal_schedule[sorted_as_steps[as_index]])

        if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
            ss.model.init_codebook(self.data_iter, 10000)

        start_time = time.time()

        for batch_num, batch in enumerate(self.device_loader):
            wav, mel, voice, jitter, position = batch
            global_step = len(ss.data.dataset) * position[0] + position[1]

            # print(f'replica {self.replica_index}, batch {batch_num}', file=stderr)
            # stderr.flush()
            if (batch_num % hps.save_interval == 0 and batch_num != 0):
                self.save_checkpoint(position)

            if hps.skip_loop_body:
                continue

            lr_index = util.greatest_lower_bound(sorted_lr_steps, global_step)
            ss.update_learning_rate(
                self.learning_rates[sorted_lr_steps[lr_index]])
            # if ss.data.global_step in self.learning_rates:
            # ss.update_learning_rate(self.learning_rates[ss.data.global_step])

            if ss.model.bn_type == 'vae' and ss.step in self.anneal_schedule:
                ss.model.objective.update_anneal_weight(
                    self.anneal_schedule[ss.data.global_step])

            ss.optim.zero_grad()
            quant, self.target, loss = self.state.model.run(
                wav, mel, voice, jitter)
            self.probs = self.softmax(quant)
            self.mel_enc_input = mel
            # print(f'after model.run', file=stderr)
            # stderr.flush()
            loss.backward()

            # print(f'after loss.backward()', file=stderr)
            # stderr.flush()

            if batch_num % hps.progress_interval == 0:
                pars_copy = [p.data.clone() for p in ss.model.parameters()]

            # print(f'after pars_copy', file=stderr)
            # stderr.flush()

            if self.is_tpu:
                xm.optimizer_step(ss.optim)
            else:
                ss.optim.step()

            ss.optim_step += 1

            if ss.model.bn_type == 'vqvae-ema' and ss.data.global_step == 10000:
                ss.model.bottleneck.update_codebook()

            tprb_m = self.avg_prob_target()

            if batch_num % hps.progress_interval == 0:
                iterator = zip(pars_copy, ss.model.named_parameters())
                uw_ratio = {
                    np[0]: t.norm(c - np[1].data) / c.norm()
                    for c, np in iterator
                }

                writer_stats.update({'uwr': uw_ratio})

                if self.is_tpu:
                    count = torch_xla._XLAC._xla_get_replication_devices_count()
                    loss_red, tprb_red = xm.all_reduce('sum', [loss, tprb_m],
                                                       scale=1.0 / count)
                    # loss_red = xm.all_reduce('all_loss', loss, reduce_mean)
                    # tprb_red = xm.all_reduce('all_tprb', tprb_m, reduce_mean)
                else:
                    loss_red = loss
                    tprb_red = tprb_m

                writer_stats.update({
                    'loss_r': loss_red,
                    'tprb_r': tprb_red,
                    'optim_step': ss.optim_step
                })

                current_stats.update({
                    'optim_step': ss.optim_step,
                    'gstep': global_step,
                    # 'gstep': ss.data.global_step,
                    'epoch': position[0],
                    'step': position[1],
                    # 'loss': loss,
                    'lrate': ss.optim.param_groups[0]['lr'],
                    # 'tprb_m': tprb_m,
                    # 'pk_d_m': avg_peak_dist
                })
                current_stats.update(ss.model.objective.metrics)

                if ss.model.bn_type in ('vae',):
                    current_stats['free_nats'] = ss.model.objective.free_nats
                    current_stats['anneal_weight'] = \
                            ss.model.objective.anneal_weight.item()

                if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                    current_stats.update(ss.model.encoder.metrics)

                if self.is_tpu:
                    xm.add_step_closure(self.train_update,
                                        args=(writer_stats, current_stats))
                else:
                    self.train_update(writer_stats, current_stats)

                # if not self.is_tpu or xm.is_master_ordinal():
                # if batch_num in range(25, 50) or batch_num in range(75, 100):
                stderr.flush()
                elapsed = time.time() - start_time