Example #1
File: model.py Project: jgalle29/gait
    def __init__(self, model=None, optimizer=None, learning_rate=-1, momentum=-1, cuda=True, load_file=None,
                 save_every=500, save_file=None, max_save_files=5, debug=False):
        self.model = model
        self.save_every = save_every
        self.save_file = save_file
        self.max_save_files = max_save_files
        self.debug = debug
        if debug:
            torch.set_anomaly_enabled(True)

        if isinstance(optimizer, str):
            if optimizer == 'adam':
                if learning_rate < 0.0:
                    learning_rate = 1e-3
                self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
            elif optimizer == 'rmspropc':
                if learning_rate < 0.0:
                    learning_rate = 0.01
                if momentum < 0.0:
                    momentum = 0.0
                self.optimizer = optim.RMSprop(self.model.parameters(), lr=learning_rate, momentum=momentum,
                                               centered=True)
        else:
            self.optimizer = optimizer
        self.device = torch.device("cuda" if cuda else "cpu")
        self.train_steps = 0
        if self.model is not None:
            self.model.to(self.device)

        self.initialize(load_file)
Example #2
    def __exit__(self, exc_type, exc_val, traceback) -> None:
        torch.set_anomaly_enabled(self.prev)
        if self.mode:
            # redirect stderr back to the saved fd
            self._redirect_stderr(self.stderr_fd_copy)
            # Copy contents of temporary file to the given stream
            self.tfile.flush()
            self.tfile.seek(0, io.SEEK_SET)
            self.stream.write(self.tfile.read())
            self.tfile.close()
            os.close(self.stderr_fd_copy)
            stream_value = self.stream.getvalue().decode('utf-8')

            if exc_val:
                raise RuntimeError(stream_value).with_traceback(traceback)
            else:
                log.warning(stream_value)
Example #3
 def __init__(self, mode: bool) -> None:
     self.prev = torch.is_anomaly_enabled()
     self.mode = mode
     torch.set_anomaly_enabled(mode)
     if self.mode:
         warnings.warn(
             'Anomaly Detection has been enabled. '
             'This mode will increase the runtime '
             'and should only be enabled for debugging.',
             stacklevel=2)
         self.stream = io.BytesIO()
         # The original fd stderr points to. Usually 2 on POSIX systems.
         self.stderr_fd_origin = sys.stderr.fileno()
         # Make a copy of the original stderr fd in stderr_fd_copy
         self.stderr_fd_copy = os.dup(self.stderr_fd_origin)
         # Create a temporary file and redirect stderr to it
         self.tfile = tempfile.TemporaryFile(mode='w+b')
         self._redirect_stderr(self.tfile.fileno())
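
Note: Examples #2 and #3 are the __exit__ and __init__ of the same stderr-capturing anomaly-mode context manager. A minimal, self-contained sketch of the simpler save-and-restore pattern (as in Examples #5, #6, #10, and #14) could look like this; AnomalyMode is a hypothetical name, not taken from any of the projects above.

import torch

class AnomalyMode:
    """Hypothetical context manager: enable autograd anomaly detection
    for a block and restore the previous setting on exit."""

    def __init__(self, mode: bool) -> None:
        self.mode = mode

    def __enter__(self) -> None:
        self.prev = torch.is_anomaly_enabled()
        torch.set_anomaly_enabled(self.mode)

    def __exit__(self, *args) -> bool:
        torch.set_anomaly_enabled(self.prev)
        return False

# Usage: anomaly detection is active only inside the with-block.
with AnomalyMode(True):
    x = torch.randn(3, requires_grad=True)
    x.sum().backward()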
Example #4
    def __init__(self, model=None, optimizer=None, learning_rate=-1, momentum=-1, cuda=True, load_file=None,
                 save_every=500, save_file=None, max_save_files=5, debug=False):
        self.model = model
        self.save_every = save_every
        self.save_file = save_file
        self.max_save_files = max_save_files
        self.debug = debug
        self.device = torch.device("cuda" if cuda else "cpu")
        self.epoch = 0
        self.train_steps = 0
        if debug:
            torch.set_anomaly_enabled(True)

        self.setup_optimizer(optimizer, learning_rate, momentum)
        if self.model is not None:
            self.model.to(self.device)

        self.initialize(load_file)
Example #5
 def __exit__(self, *args: Any) -> None:
     torch.set_anomaly_enabled(self.prev)
Example #6
 def __enter__(self) -> None:
     torch.set_anomaly_enabled(True)
Example #7
def train(rank, args):
    print('enter train @ %s'%(rank), flush=True)
    args.rank = rank
    args.split = ''
    torch.manual_seed(42)
    save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size() if not args.vocab_size else args.vocab_size
    
    train_dataset = get_dataset(args)
    
    batched_already = hasattr(train_dataset, '__getbatch__')

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)
    
    train_sampler = None
    if args.gpus:
        dist.init_process_group(
            'nccl', 
            rank=rank, 
            world_size=args.world_size
        )
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=args.shuffle)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=args.shuffle)


    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size if not batched_already else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)
        

    eval_loaders = []
    if args.eval_dir:
        for split in args.splits.split(','):
            split = split.strip()
            args.split = split
            eval_dataset = get_eval_dataset(args)

            eval_sampler = None
            if args.gpus:
                if args.gpus > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=args.gpus,
                        rank=rank,
                        shuffle=False)

            else:
                rank = xm.get_ordinal()
                if xm.xrt_world_size() > 1:
                    eval_sampler = torch.utils.data.distributed.DistributedSampler(
                        eval_dataset,
                        num_replicas=xm.xrt_world_size(),
                        rank=rank,
                        shuffle=False)

            eval_loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=args.batch_size if not batched_already else None,
                sampler=eval_sampler,
                pin_memory=True,
                shuffle=False,
                num_workers=args.num_workers)
            eval_loaders.append(eval_loader)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args, tokenizer)

        model.cuda(rank)

        device = torch.device('cuda:'+str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay,
                                 special_layer_wise_lr=args.special_layer_wise_lr,
                                 log = rank == 0,
                                 ),  

                                 # use this function to set extra optimizer arguments, 
                                 # see model_get_parameters
            betas=(0.9, 0.999), 
            eps=1e-6,
            lr=args.lr, 
            weight_decay=args.weight_decay
        )




        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()


        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        
        model = get_model(args, tokenizer)


        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  so we find the shared parameters first
        ##
        ##########################

        shared_parameters = {e[0]: e[1:] for e in _catalog_shared_params(model)}

        model.to(device)
        
        do_share_parameters_again(model, shared_parameters, log = rank == 0)



        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay
                                 ),  

                                 # use this function to set extra optimizer arguments, 
                                 # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay
        )


        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)
                
        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
        
        
        
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
                
        if args.gpus:
            states = {"module.%s"%k : v for k, v in states.items()}
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            if rank == 0:
                traceback.print_exc()
            model.load_state_dict(states, strict=False)
            
        
    if rank == 0:
        if not os.path.exists(os.path.dirname(save_fn)):
            try:
                os.makedirs(os.path.dirname(save_fn))
            except OSError as exc: # Guard against race condition
                if exc.errno != errno.EEXIST:
                    raise
        if args.gpus:
            torch.save(model.state_dict(), save_fn )
        else:
            xm.save(model.state_dict(), save_fn )
        
    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################
    
    if not batched_already:
        args.total_num_updates = args.total_num_updates // args.batch_size
        args.warmup_updates = args.warmup_updates // args.batch_size

    args.total_num_updates = args.total_num_updates // args.world_size
    args.warmup_updates = args.warmup_updates // args.world_size

    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=args.warmup_updates, 
        num_training_steps=args.total_num_updates, 
    )

    step_i = 0

    err = None
    tb = None
    #tb = SummaryWriter()
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates, file=sys.stdout)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
                
            n_samples = len(batches)
                
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True:  # loop only to replay the batch after an apex gradient overflow
                    optimizer.zero_grad()

                    total_loss, log = get_loss(
                        model, 
                        sample, 
                        args=args, 
                        device=device, 
                        gpus=args.gpus, 
                        report=report_step
                    )

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        optimizer.step() # If an overflow was detected, "optimizer.step" is the patched call, which does 
                                         # nothing but restore optimizer.step to default_optimizer_step.
                        if rank == 0:
                            print("Overflowed, reducing loss scale and replaying batch.", flush=True)
                        
                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break



                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    # tb.add_scalar("Loss", total_loss, step_i)

                    for k, v in log.items():
                        try:
                            dist.all_reduce(v, op=dist.reduce_op.SUM)
                            log[k] = float(v)
                        except Exception as e:
                            print(v, e)
                            pass
                        
                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(log, log_formatter, tb, step_i))
                    else:
                        xm.add_step_closure(_train_update, args=(log, log_formatter, tb, step_i))

                    if args.report_metrics:
                        xm.master_print(met.metrics_report())

                
                if rank == 0:
                    pbar.update(1)

        if rank == 0:
            pbar.close()
        if eval_loaders:
            model.half()
            model.eval()
            model.cuda()
            for k, v in model.named_parameters():
                v.requires_grad = False

                
            for split, eval_loader in zip(args.splits.split(','), eval_loaders):
                batches = eval_loader
                if rank == 0:
                    eval_length = len(batches)
                    if not batched_already:
                        eval_length = eval_length // args.batch_size

                    eval_length = eval_length // args.world_size

                    pbar = tqdm(total=eval_length, file=sys.stdout)
                
                if not args.gpus:
                    batches = pl.ParallelLoader(eval_loader, [device]).per_device_loader(device)
                with torch.no_grad():
                    record = OrderedDict()

                    for sample in batches:
                        evaluate(
                            model, 
                            sample, 
                            args=args, 
                            device=device, 
                            record=record,
                            gpus=args.gpus, 
                            report=False
                        )
                        if rank == 0:
                            pbar.update(1)

                    for k, v in record.items():
                        try:
                            def handle_reduce(v):
                                if len(v.shape) == 0:
                                    dist.all_reduce(v, op=dist.reduce_op.SUM)
                                else:
                                    L = [torch.ones_like(v) for _ in range(dist.get_world_size())]
                                    dist.all_gather(L, v)
                                    v = torch.cat(L, dim=0)
                                return v
                            if isinstance(v, list):
                                v = [handle_reduce(e) for e in v]
                            else:
                                v = handle_reduce(v)
                            record[k] = float(v)
                        except Exception as e:
                            pass

                    post_evaluate(record, args=args)

                import json

                if rank == 0:
                    print('',flush=True)
                    print('Test result for %s'%split, flush=True)
                    print(json.dumps(record, indent=2),flush=True)
                    print('',flush=True)


    except Exception as _err:
        err = _err
    finally:
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0:
            print("Saving to %s"%save_fn)
            if args.gpus:
                torch.save(model.state_dict(), save_fn )
                if err:
                    raise err
            else:
                xm.save(model.state_dict(), save_fn )
                if err:
                    raise err
            print("Saved to %s"%save_fn)
Example #8
 def __enter__(self) -> None:
     torch.set_anomaly_enabled(True, self.check_nan)
Example #9
    def bptt_train(self, epochs=1, print_interval=1, k1=1, k2=3, seq=5, skip=None, debug=False):
        cur_time = prev_time = time.time()
        torch.set_anomaly_enabled(debug)
        retain_graph = k1 < k2

        for epoch in range(1, epochs + 1):
            self.model.train()

            net_state_hidden, net_cell_states = self.init_rnn_state()
            epoch_loss_collector = LossCollector()
            states = [(None, self.init_rnn_state())]

            total_loss = 0
            for i, data in enumerate(self.train_data):
                if skip and i % skip != 0:
                    continue

                if data.x is None or torch.isnan(data.x).any():
                    # print("Skipping", i, data.x, torch.isnan(data.x).any(), data.x is None)
                    continue

                data = data.to(self.device)
                self.optimizer.zero_grad()

                # Compute gradients
                outputs, expected, net_state_hidden, net_cell_states = self.feed(data, net_state_hidden, net_cell_states)
                
                def repackage_hidden(h):
                    """Wraps hidden states in new Tensors, to detach them from their history."""

                    if isinstance(h, torch.Tensor):
                        return h.detach()
                    else:
                        return tuple(repackage_hidden(v) for v in h)

                mask = data.current_transports[:, -self.pred_seq_len:]
                outputs[~mask] = 0
                expected[~mask] = 0

                l1_regularization = 0.
                for param in self.model.parameters():
                    l1_regularization = l1_regularization + param.abs().sum()
                
                loss = self.loss(outputs, expected)

                loss = loss + self.l1_reg * l1_regularization

                total_loss = total_loss + loss

                total_norm = 0
                for p in self.model.parameters():
                    if p.grad is None:
                        continue
                    param_norm = p.grad.data.norm(2)
                    total_norm = total_norm + param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                # print("Paramm norm:", total_norm)

                if i % seq == 0:
                    total_loss.backward()
                    epoch_loss_collector.collect(outputs, expected)
                    # Update parameters
                    nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                    self.optimizer.step()

                    # Cut the gradient graph
                    net_state_hidden = repackage_hidden(net_state_hidden)
                    net_cell_states = repackage_hidden(net_cell_states)
                    total_loss = 0

            loss = epoch_loss_collector.reduce()
            self.collect_train_metrics(loss)

            cur_time = time.time()
            log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f}, Time: {:.1f}s"
            if epoch % print_interval == 0:
                val_loss = dict(mse=-1, mae=-1, acc=-1)
                val_acc = [-1]
                if epoch % val_interval == 0:
                    # Validate, this uses another hidden state for the model
                    val_acc, val_loss = self.test()
                print(
                    log.format(
                        epoch,
                        loss["mse"], loss["mae"],
                        val_loss["mse"], val_loss["mae"],
                        val_loss["acc"],
                        cur_time - prev_time,
                    )
                )
            prev_time = cur_time

        self.print_eval_summary()
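
Note: Example #9 gates torch.set_anomaly_enabled(debug) behind a debug flag and cuts the autograd graph every `seq` steps via repackage_hidden. A stripped-down sketch of that truncated-BPTT detach pattern, with an nn.LSTM and random data standing in for the project's model and loader:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16)   # hypothetical stand-in model
optimizer = torch.optim.Adam(lstm.parameters())
hidden = None        # becomes an (h, c) tuple after the first step
seq = 5              # backprop/detach interval, as in Example #9
total_loss = 0

for i in range(1, 51):
    x = torch.randn(1, 1, 8)                   # one time step, batch of 1
    out, hidden = lstm(x, hidden)
    total_loss = total_loss + out.pow(2).mean()

    if i % seq == 0:
        optimizer.zero_grad()
        total_loss.backward()                  # backprop through the last `seq` steps
        nn.utils.clip_grad_norm_(lstm.parameters(), 1)
        optimizer.step()
        # Detach the hidden state so the next window does not backprop
        # into this one (the repackage_hidden idea from Example #9).
        hidden = tuple(h.detach() for h in hidden)
        total_loss = 0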
Example #10
 def __init__(self, mode):
     self.prev = torch.is_anomaly_enabled()
     torch.set_anomaly_enabled(mode)
Example #11
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib
from torch.nn.utils import clip_grad_norm_
from collections import OrderedDict
import os
from rational.torch import Rational, RecurrentRational
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

font = {'family': 'normal', 'weight': 'bold', 'size': 22}

writer = None
matplotlib.rc('font', **font)

torch.set_anomaly_enabled(True)
cnt = 0

actfvs = dict()

actfvs["pau"] = Rational
actfvs["relu"] = nn.ReLU

#_rational = Rational()
#def shared_Rational():
#    return RecurrentRational(_rational)

actfvs["recurrent_pau"] = RecurrentRational()


def vgg_block(num_convs, in_channels, num_channels, actv_function):
Example #12
 def __init__(self, mode):
     self.prev = torch.is_anomaly_enabled()
     torch.set_anomaly_enabled(mode)
Example #13
 def __exit__(self, *args: Any) -> bool:
     torch.set_anomaly_enabled(self.prev)
     return False
Example #14
 def __init__(self, mode: bool) -> None:
     self.prev = torch.is_anomaly_enabled()
     torch.set_anomaly_enabled(mode)
Example #15
 def __exit__(self, *args):
     torch.set_anomaly_enabled(self.prev)
     return False
Example #16
 def unknown_op(x):
     torch.set_anomaly_enabled(True)
     return x
Example #17
def train(rank, args):
    print('enter train @ %s' % (rank), flush=True)
    args.rank = rank
    torch.manual_seed(42)

    tokenizer = get_tokenizer(args)
    args.vocab_size = tokenizer._tokenizer.get_vocab_size()

    train_dataset = get_dataset(args)

    if args.total_num_updates < 100:
        args.total_num_updates = len(train_dataset) * args.total_num_updates

    if args.warmup_updates < 1:
        args.warmup_updates = int(args.total_num_updates * args.warmup_updates)
    else:
        args.warmup_updates = int(args.warmup_updates)

    train_sampler = None
    if args.gpus:
        dist.init_process_group('nccl', rank=rank, world_size=args.world_size)
        if args.gpus > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=args.gpus,
                rank=rank,
                shuffle=False)

    else:
        rank = xm.get_ordinal()
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=rank,
                shuffle=False)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size
        if not hasattr(train_dataset, '__getbatch__') else None,
        sampler=train_sampler,
        pin_memory=True,
        shuffle=False,
        num_workers=args.num_workers)

    eval_loader = None
    if args.eval_dir:

        eval_dataset = get_eval_dataset(args)

        eval_sampler = None
        if args.gpus:
            if args.gpus > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=args.gpus,
                    rank=rank,
                    shuffle=False)

        else:
            rank = xm.get_ordinal()
            if xm.xrt_world_size() > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(
                    eval_dataset,
                    num_replicas=xm.xrt_world_size(),
                    rank=rank,
                    shuffle=False)

        eval_loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=args.batch_size
            if not hasattr(train_dataset, '__getbatch__') else None,
            sampler=eval_sampler,
            pin_memory=True,
            shuffle=False,
            num_workers=args.num_workers)

    if args.gpus:
        assert apex_enabled
        torch.cuda.set_device(rank)

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################
        model = get_model(args)

        model.cuda(rank)

        device = torch.device('cuda:' + str(rank))

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = apex.optimizers.FusedAdam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            betas=(0.9, 0.999),
            eps=1e-6,
            lr=args.lr,
            weight_decay=args.weight_decay)

        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        model = DDP(model)
        batches = train_loader

    else:
        assert tpu_enabled
        device = xm.xla_device()

        ##########################
        ##
        ##  Model Creation
        ##
        ##########################

        model = get_model(args)

        ##########################
        ##
        ##  For shared parameters, TPU requires modules to be tied after .to(device)
        ##  so we find the shared parameters first
        ##
        ##########################

        shared_parameters = {
            e[0]: e[1:]
            for e in _catalog_shared_params(model)
        }

        model.to(device)

        do_share_parameters_again(model, shared_parameters, log=rank == 0)

        ##########################
        ##
        ##  Init Optimizer
        ##
        ##########################

        optimizer = optim.Adam(
            model_get_parameters(model,
                                 lr=args.lr,
                                 lw_lr_decay=args.lw_lr_decay,
                                 weight_decay=args.weight_decay),

            # use this function to set extra optimizer arguments,
            # see model_get_parameters
            lr=args.lr,
            weight_decay=args.weight_decay)

        writer = None
        if xm.is_master_ordinal():
            writer = test_utils.get_summary_writer(args.save_dir)

        xm.rendezvous("load_checkpoint")  # wait for all workers
        xm.mark_step()

        # tracker = xm.RateTracker()
    if args.restore_file:
        states = torch.load(args.restore_file, map_location=device)
        for k, v in list(states.items()):
            if k.startswith('module.'):
                del states[k]
                k = k[7:]
                states[k] = v
            if k.endswith('position_ids'):
                del states[k]
                states[k[:-12] + 'position_embeddings'] = v
        try:
            model.load_state_dict(states)
        except Exception as err:
            import traceback
            traceback.print_exc()
            model.load_state_dict(states, strict=False)

    model.train()

    if args.anomaly_detection and rank == 0:
        torch.set_anomaly_enabled(True)

    ##########################
    ##
    ##  Init LR Scheduler
    ##
    ##########################

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_updates,
        num_training_steps=args.total_num_updates,
    )

    step_i = 0

    err = None
    try:
        if rank == 0:
            pbar = tqdm(total=args.total_num_updates)
        while step_i < args.total_num_updates:
            if not args.gpus:
                batches = pl.ParallelLoader(train_loader,
                                            [device]).per_device_loader(device)
            for sample in batches:
                step_i += 1
                if step_i > args.total_num_updates:
                    break

                report_step = step_i % args.log_interval == 0

                while True:  # loop only to replay the batch after an apex gradient overflow
                    optimizer.zero_grad()

                    total_loss, log = get_loss(model,
                                               sample,
                                               args=args,
                                               device=device,
                                               gpu=args.gpus,
                                               report=report_step)

                    if args.gpus:
                        default_optimizer_step = optimizer.step

                        with amp.scale_loss(total_loss,
                                            optimizer) as scaled_loss:
                            scaled_loss.backward()

                        # If Amp detects an overflow, it patches optimizer.step.  In other words, if optimizer.step
                        # was left unpatched, there was no overflow, and we don't need to replay.
                        if optimizer.step is default_optimizer_step:
                            optimizer.step()
                            break

                        optimizer.step(
                        )  # If an overflow was detected, "optimizer.step" is the patched call, which does
                        # nothing but restore optimizer.step to default_optimizer_step.
                        if rank == 0:
                            print(
                                "Overflowed, reducing loss scale and replaying batch.",
                                flush=True)

                    else:
                        total_loss.backward()
                        xm.optimizer_step(optimizer)
                        xm.mark_step()

                        break

                scheduler.step()

                if report_step:
                    if 'loss' not in log:
                        log['loss'] = total_loss

                    if args.gpus:
                        if rank == 0:
                            pbar.set_description(format_log(
                                log, log_formatter))
                    else:
                        xm.add_step_closure(_train_update,
                                            args=(log, log_formatter))

                    if args.report_metrics:
                        xm.master_print(met.metrics_report())

                if rank == 0:
                    pbar.update(1)

        if eval_loader is not None:
            model.eval()
            if not args.gpus:
                batches = pl.ParallelLoader(eval_loader,
                                            [device]).per_device_loader(device)
            with torch.no_grad():
                record = OrderedDict()

                for sample in batches:
                    evaluate(model,
                             sample,
                             args=args,
                             device=device,
                             record=record,
                             gpu=args.gpus,
                             report=report_step)

                post_evaluate(record, args=args)

            import json
            print('', flush=True)
            print(json.dumps(record), flush=True)
            print('', flush=True)

    except Exception as _err:
        err = _err
    finally:
        save_fn = os.path.join(args.save_dir, 'checkpoint_final.pt')
        folder = os.path.split(os.path.abspath(save_fn))[0]
        os.makedirs(folder, exist_ok=True)
        if rank == 0 and args.gpus:
            torch.save(model.state_dict(), save_fn)
        elif not args.gpus:
            xm.save(model.state_dict(), save_fn)
        if err:
            raise err
Example #18
 def __enter__(self):
     torch.set_anomaly_enabled(True)
Example #19
 def __init__(self, mode: bool, check_nan: bool = True) -> None:
     self.prev = torch.is_anomaly_enabled()
     self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
     torch.set_anomaly_enabled(mode, check_nan)
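
Note: Examples #8 and #19 use the two-argument form found in newer PyTorch releases, where the second flag controls the per-op NaN check. A small sketch of saving and restoring both settings (version-dependent; the calls mirror Example #19):

import torch

prev_mode = torch.is_anomaly_enabled()
prev_check_nan = torch.is_anomaly_check_nan_enabled()
torch.set_anomaly_enabled(True, False)  # keep the enhanced tracebacks, skip the NaN assertions
try:
    x = torch.randn(4, requires_grad=True)
    (x * 2).sum().backward()
finally:
    torch.set_anomaly_enabled(prev_mode, prev_check_nan)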
Example #20
 def __exit__(self, *args):
     torch.set_anomaly_enabled(self.prev)
     return False
Example #21
    def train(self, epochs=1, print_interval=1, val=True, debug=False):
        cur_time = prev_time = time.time()
        torch.set_anomaly_enabled(debug)

        net_state_hidden, net_cell_states = self.init_rnn_state()

        for epoch in range(1, epochs + 1):
            self.model.train()

            epoch_loss_collector = LossCollector()

            for i, data in enumerate(self.train_data):
                # print("%d of %d" % (i, len(self.train_data)))

                if data.x is None or torch.isnan(data.x).any():
                    # print("Skipping", i, data.x, torch.isnan(data.x).any(), data.x is None)
                    continue

                data = data.to(self.device)
                self.optimizer.zero_grad()

                # Compute gradients
                outputs, expected, net_state_hidden, net_cell_states = self.feed(data, net_state_hidden, net_cell_states)
                mask = data.current_transports[:, -self.pred_seq_len :]
                # print(outputs.shape, expected.shape)
                outputs[~mask] = 0
                expected[~mask] = 0
                
                l1_regularization = 0.
                for param in self.model.parameters():
                    l1_regularization = l1_regularization + param.abs().sum()
                
                assert not torch.isnan(expected).any()
                loss = self.loss(outputs, expected)
                
                loss = loss + self.l1_reg * l1_regularization

                loss.backward()

                epoch_loss_collector.collect(outputs, expected)

                total_norm = 0
                for p in self.model.parameters():
                    param_norm = p.grad.data.norm(2)
                    total_norm = total_norm + param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
                assert total_norm == total_norm
                # print("Param norm:", total_norm)

                # Update parameters
                # nn.utils.clip_grad_norm_(self.model.parameters(), 10e0)
                self.optimizer.step()

            if self.lr_scheduler:
                self.lr_scheduler.step()
            loss = epoch_loss_collector.reduce()
            self.collect_train_metrics(loss)

            cur_time = time.time()
            log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f}, Time: {:.1f}s"
            if epoch % print_interval == 0:
                val_loss = dict(mse=-1, mae=-1, acc=-1)
                if val and epoch % val_interval == 0:
                    # Validate
                    val_acc, val_loss = self.test()
                print(
                    log.format(
                        epoch,
                        loss["mse"], loss["mae"],
                        val_loss["mse"], val_loss["mae"],
                        val_loss["acc"],
                        cur_time - prev_time,
                    )
                )
            prev_time = cur_time

        self.print_eval_summary()
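
Note: the common thread across these examples is gating anomaly detection behind a debug flag, since it slows the backward pass considerably. A minimal end-to-end sketch (the names are illustrative, not from the projects above):

import torch

def debug_backward(loss: torch.Tensor, debug: bool = False) -> None:
    prev = torch.is_anomaly_enabled()
    torch.set_anomaly_enabled(debug)  # the same call used throughout the examples
    try:
        # With debug=True, an error raised during backward also reports the
        # forward operation that produced the failing gradient.
        loss.backward()
    finally:
        torch.set_anomaly_enabled(prev)

x = torch.randn(5, requires_grad=True)
debug_backward((x ** 2).sum(), debug=True)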