class LRFinder():
    def __init__(self, optim, num_iter, low_lr=1e-7, high_lr=10):
        self.optim = optim
        self.num_iter = num_iter
        self.lambda_func = lambda batch: get_running_factor(
            low_lr, high_lr, self.num_iter, batch)
        self.scheduler = LambdaLR(optim, lr_lambda=self.lambda_func)
        self.learning_rates = []
        self.losses = []

    def find(self, model, input, loss_fn):
        smooth_loss = None
        best_loss = np.inf
        stop_training = False
        num_epochs = int(np.floor(self.num_iter / len(input)))
        for epoch in list(range(num_epochs)):
            if stop_training:
                break
            model.train()
            self.optim.zero_grad()
            for i, batch in enumerate(input):
                print("Batch {} has learning rate {}".format(
                    i, self.scheduler.get_lr()))
                images, masks = batch
                pred = model(images)

                loss = loss_fn(pred, masks)
                if smooth_loss is not None:
                    smooth_loss = ewma(smooth_loss, loss.item())
                else:
                    smooth_loss = loss.item()
                if smooth_loss < best_loss:
                    best_loss = smooth_loss
                if smooth_loss > 4 * best_loss or np.isnan(smooth_loss):
                    stop_training = True
                    break

                self.learning_rates.append(self.scheduler.get_lr())
                self.losses.append(loss.item())

                loss.backward()
                self.optim.step()
                self.scheduler.step()
                self.optim.zero_grad()
        self.learning_rates = np.array(self.learning_rates).flatten()
        self.losses = np.array(self.losses)

    def get_learning_rates(self):
        return self.learning_rates

    def get_losses(self):
        return self.losses
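
The LRFinder above relies on two helpers that are not shown in this snippet: get_running_factor, which LambdaLR calls to produce the multiplicative factor for each batch, and ewma, which smooths the loss curve. A minimal sketch of what they might look like, assuming the optimizer is constructed with a base learning rate of 1.0 so the factor equals the learning rate itself:

def get_running_factor(low_lr, high_lr, num_iter, batch):
    # Exponential LR range test: sweep the learning rate geometrically from
    # low_lr to high_lr over num_iter batches (assumes a base lr of 1.0).
    return low_lr * (high_lr / low_lr) ** (batch / num_iter)


def ewma(prev, value, beta=0.98):
    # Exponentially weighted moving average used to smooth the recorded loss.
    return beta * prev + (1.0 - beta) * value
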
def train(classifier, train_loader, test_loader, args):
    optimizer = torch.optim.SGD(classifier.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    best_train_loss = np.inf
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
            step,
            args.epochs * len(train_loader),
            1,  # lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = run_epoch(classifier,
                                          train_loader,
                                          args,
                                          optimizer=optimizer,
                                          scheduler=scheduler)
        lr = scheduler.get_lr()[0]
        logger.info(
            'Epoch: {}, lr: {:.4f}, training loss: {:.4f}, acc: {:.4f}.'.
            format(epoch, lr, train_loss, train_acc))

        test_loss, test_acc = run_epoch(classifier, test_loader, args)
        logger.info("Test loss: {:.4f}, acc: {:.4f}".format(
            test_loss, test_acc))

        if train_loss < best_train_loss:
            best_train_loss = train_loss
            save_name = 'resnet18_wd{}.pth'.format(args.weight_decay)
            state = classifier.state_dict()

            torch.save(state, save_name)
            logger.info(
                "==> New optimal training loss & saving checkpoint ...")
# Example 3
class Trainer(object):
    ''' An object that encapsulates model training '''
    def __init__(self, config, model, dataloader, device):
        self.model = model
        self.config = config
        self.device = device
        self.stopped_early = False
        self.dataloader = dataloader
        self.validation_dataloader = dataloader
        self.last_checkpoint_time = time.time()

        if 'cuda' in device.type:
            self.model = nn.DataParallel(model.cuda())

        if self.config.optimizer == "adam":
            self.optimizer = optim.Adam(model.parameters(),
                                        config.base_lr,
                                        betas=(0.9, 0.98),
                                        eps=1e-9)
            # self.optimizer = optim.Adam(model.parameters(), 1e-7, betas=(0.9, 0.98), eps=1e-9)
            if config.lr_scheduler == 'warmup':
                self.lr_scheduler = LambdaLR(
                    self.optimizer, WarmupLRSchedule(config.warmup_steps))

            elif config.lr_scheduler == 'warmup2':
                self.lr_scheduler = LambdaLR(
                    self.optimizer, WarmupLRSchedule2(config.warmup_steps))

            elif config.lr_scheduler == 'linear':
                self.lr_scheduler = LambdaLR(
                    self.optimizer,
                    LinearLRSchedule(config.base_lr, config.final_lr,
                                     config.max_steps))
            elif config.lr_scheduler == 'exponential':
                self.lr_scheduler = ExponentialLR(self.optimizer,
                                                  config.lr_decay)
            else:
                raise ValueError('Unknown learning rate scheduler!')

        elif self.config.optimizer == "sgd":
            print("using optimizer: SGD")
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=config.base_lr,
                                       momentum=0.9)
            self.lr_scheduler = LambdaLR(self.optimizer,
                                         DummyLRSchedule(config.base_lr))

        elif self.config.optimizer == "adam-fixed":
            print("using optimizer: adam with fixed learning rate")
            self.optimizer = optim.Adam(model.parameters(),
                                        config.base_lr,
                                        betas=(0.9, 0.98),
                                        eps=1e-9)
            self.lr_scheduler = LambdaLR(self.optimizer,
                                         DummyLRSchedule(config.base_lr))

        else:
            raise ValueError('Unknown optimizer!')

        # Initialize the metrics
        metrics_path = os.path.join(self.config.checkpoint_directory,
                                    'train_metrics.pt')
        self.metric_store = metrics.MetricStore(metrics_path)
        self.metric_store.add(metrics.Metric('oom', metrics.format_int, 't'))
        self.metric_store.add(
            metrics.Metric('nll', metrics.format_float, max_history=1000))
        self.metric_store.add(
            metrics.Metric('lr', metrics.format_scientific, 'g',
                           max_history=1))
        self.metric_store.add(
            metrics.Metric('num_tok',
                           metrics.format_int,
                           'a',
                           max_history=1000))
        # self.metric_store.add(metrics.Metric('time_per_batch', metrics.format_float, 'g', max_history=100000))
        # self.metric_store.add(metrics.Metric('time_total', metrics.format_float, 'g', max_history=1))

        if self.config.early_stopping:
            self.metric_store.add(
                metrics.Metric('vnll', metrics.format_float, 'g'))

        self.modules = {
            'model': model,
            'optimizer': self.optimizer,
            'lr_scheduler': self.lr_scheduler
        }

    @property
    def dataset(self):
        ''' Get the dataset '''
        return self.dataloader.dataset

    def train_epoch(self, epoch, experiment, verbose=0):
        ''' Run one training epoch '''
        oom = self.metric_store['oom']
        learning_rate = self.metric_store['lr']
        num_tokens = self.metric_store['num_tok']
        neg_log_likelihood = self.metric_store['nll']

        def try_optimize(i, last=False):
            # optimize if:
            #  1) last and remainder
            #  2) not last and not remainder
            remainder = bool(i % self.config.accumulate_steps)
            if not last ^ remainder:
                next_lr = self.optimize()

                learning_rate.update(next_lr)
                experiment.log_metric('learning_rate', next_lr)
                return True

            return False

        def get_description():
            description = f'Train #{epoch}'
            if verbose > 0:
                description += f' {self.metric_store}'
            if verbose > 1:
                description += f' [{profile.mem_stat_string(["allocated"])}]'
            return description

        batches = tqdm(
            self.dataloader,
            unit='batch',
            dynamic_ncols=True,
            desc=get_description(),
            file=sys.stdout  # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            i = 1
            nll_per_update = 0.
            length_per_update = 0
            num_tokens_per_update = 0
            for i, batch in enumerate(batches, 1):

                try:

                    nll, length = self.calculate_gradient(batch)
                    did_optimize = try_optimize(i)

                    # record the effective number of tokens
                    num_tokens_per_update += int(sum(batch['input_lens']))
                    num_tokens_per_update += int(sum(batch['target_lens']))

                    if length:
                        # record length and nll
                        nll_per_update += nll
                        length_per_update += length

                    if did_optimize:
                        # advance the experiment step
                        experiment.set_step(experiment.curr_step + 1)

                        num_tokens.update(num_tokens_per_update)
                        neg_log_likelihood.update(nll_per_update /
                                                  length_per_update)

                        experiment.log_metric('num_tokens',
                                              num_tokens_per_update)
                        experiment.log_metric('nll',
                                              neg_log_likelihood.last_value)
                        # experiment.log_metric('max_memory_alloc', torch.cuda.max_memory_allocated()//1024//1024)
                        # experiment.log_metric('max_memory_cache', torch.cuda.max_memory_cached()//1024//1024)

                        nll_per_update = 0.
                        length_per_update = 0
                        num_tokens_per_update = 0

                except RuntimeError as rte:
                    if 'out of memory' in str(rte):
                        torch.cuda.empty_cache()

                        oom.update(1)
                        experiment.log_metric('oom', oom.total)
                        #exit(-1)
                    else:
                        batches.close()
                        raise rte

                if self.should_checkpoint():
                    new_best = False
                    if self.config.early_stopping:
                        with tqdm_unwrap_stdout():
                            new_best = self.evaluate(experiment, epoch,
                                                     verbose)

                    self.checkpoint(epoch, experiment.curr_step, new_best)

                batches.set_description_str(get_description())
                if self.is_done(experiment, epoch):
                    batches.close()
                    break

            try_optimize(i, last=True)

    def should_checkpoint(self):
        ''' Function which determines if a new checkpoint should be saved '''
        return (time.time() - self.last_checkpoint_time >
                self.config.checkpoint_interval)

    def checkpoint(self, epoch, step, best=False):
        ''' Save a checkpoint '''
        checkpoint_path = checkpoint(
            epoch,
            step,
            self.modules,
            self.config.checkpoint_directory,
            max_checkpoints=self.config.max_checkpoints)

        if best:
            dirname = os.path.dirname(checkpoint_path)
            basename = os.path.basename(checkpoint_path)
            best_checkpoint_path = os.path.join(dirname, f'best_{basename}')
            shutil.copy2(checkpoint_path, best_checkpoint_path)

        self.metric_store.save()
        self.last_checkpoint_time = time.time()

    def evaluate(self, experiment, epoch, verbose=0):
        ''' Evaluate the current model and determine if it is a new best '''
        model = self.modules['model']
        evaluator = Evaluator(args.ArgGroup(None), model,
                              self.validation_dataloader, self.device)
        vnll = evaluator(epoch, experiment, verbose)
        metric = self.metric_store['vnll']
        full_history = metric.values
        metric.update(vnll)
        self.metric_store.save()

        return all(vnll < nll for nll in full_history[:-1])

    def is_done(self, experiment, epoch):
        ''' Has training completed '''
        if self.config.max_steps and experiment.curr_step >= self.config.max_steps:
            return True

        if self.config.max_epochs and epoch >= self.config.max_epochs:
            return True

        if self.config.early_stopping:
            history = self.metric_store['vnll'].values[
                -self.config.early_stopping - 1:]
            if len(history) == self.config.early_stopping + 1:
                self.stopped_early = all(history[-1] > nll
                                         for nll in history[:-1])
                return self.stopped_early

        return False

    def optimize(self):
        ''' Calculate an optimization step '''
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.lr_scheduler.step()

        return self.lr_scheduler.get_lr()[0]

    def calculate_gradient(self, batch):
        ''' Runs one step of optimization '''
        # run the data through the model
        self.model.train()
        loss, nll = self.model(batch)

        # nn.DataParallel wants to gather rather than doing a reduce_add, so the output here
        # will be a tensor of values that must be summed
        nll = nll.sum()
        loss = loss.sum()

        # calculate gradients then run an optimization step
        loss.backward()

        # need to use .item() which converts to Python scalar
        # because as a Tensor it accumulates gradients
        return nll.item(), torch.sum(batch['target_lens']).item()

    def __call__(self, start_epoch, experiment, verbose=0):
        ''' Execute training '''
        with ExitStack() as stack:
            stack.enter_context(chunked_scattering())
            stack.enter_context(experiment.train())

            if start_epoch > 0 or experiment.curr_step > 0:
                # TODO: Hacky approach to decide if the metric store should be loaded. Revisit later
                self.metric_store = self.metric_store.load()

            epoch = start_epoch
            experiment.log_current_epoch(epoch)
            while not self.is_done(experiment, epoch):
                experiment.log_current_epoch(epoch)
                self.train_epoch(epoch, experiment, verbose)
                experiment.log_epoch_end(epoch)
                epoch += 1

            if self.stopped_early:
                print('Stopping early!')
            else:
                new_best = False
                if self.config.early_stopping:
                    new_best = self.evaluate(experiment, epoch, verbose)

                self.checkpoint(epoch, experiment.curr_step, new_best)
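
The schedule callables handed to LambdaLR in this Trainer (WarmupLRSchedule, WarmupLRSchedule2, LinearLRSchedule, DummyLRSchedule) are defined elsewhere in the project. The sketches below show plausible implementations that return multiplicative factors relative to config.base_lr; the warmup variant follows the familiar Transformer schedule and is an assumption, not the project's actual code.

class WarmupLRSchedule:
    ''' Linear warmup followed by inverse-square-root decay (Transformer style) '''
    def __init__(self, warmup_steps):
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = max(step, 1)
        return min(step ** -0.5,
                   step * self.warmup_steps ** -1.5) * self.warmup_steps ** 0.5


class LinearLRSchedule:
    ''' Anneal the learning rate linearly from base_lr to final_lr over max_steps '''
    def __init__(self, base_lr, final_lr, max_steps):
        self.base_lr = base_lr
        self.final_lr = final_lr
        self.max_steps = max_steps

    def __call__(self, step):
        progress = min(step / self.max_steps, 1.0)
        return (self.base_lr +
                (self.final_lr - self.base_lr) * progress) / self.base_lr


class DummyLRSchedule:
    ''' Keep the learning rate constant; base_lr is stored only for reference '''
    def __init__(self, base_lr):
        self.base_lr = base_lr

    def __call__(self, step):
        return 1.0
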
# Example 4
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose(
        [T.Resize(args.resize_size),
         T.ToTensor(), normalize])
    val_transform = T.Compose(
        [T.Resize(args.resize_size),
         T.ToTensor(), normalize])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          split='test',
                          download=True,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    num_factors = train_source_dataset.num_factors
    backbone = models.__dict__[args.arch](pretrained=True)
    bottleneck_dim = args.bottleneck_dim
    if args.normalization == 'IN':
        backbone = convert_model(backbone)
        bottleneck = nn.Sequential(
            nn.Conv2d(backbone.out_features,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
        )
        head = nn.Sequential(
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(),
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(),
            nn.Linear(bottleneck_dim, num_factors), nn.Sigmoid())
        for layer in head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        adv_head = nn.Sequential(
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(),
            nn.Conv2d(bottleneck_dim,
                      bottleneck_dim,
                      kernel_size=3,
                      stride=1,
                      padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(),
            nn.Linear(bottleneck_dim, num_factors), nn.Sigmoid())
        for layer in adv_head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        regressor = ImageRegressor(backbone,
                                   num_factors,
                                   bottleneck=bottleneck,
                                   head=head,
                                   adv_head=adv_head,
                                   bottleneck_dim=bottleneck_dim,
                                   width=bottleneck_dim)
    else:
        regressor = ImageRegressor(backbone,
                                   num_factors,
                                   bottleneck_dim=bottleneck_dim,
                                   width=bottleneck_dim)

    regressor = regressor.to(device)
    print(regressor)
    mdd = MarginDisparityDiscrepancy(args.margin).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(regressor.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        train_source_loader = DataLoader(train_source_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        train_target_loader = DataLoader(train_target_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         drop_last=True)
        # extract features from both domains
        feature_extractor = nn.Sequential(regressor.backbone,
                                          regressor.bottleneck,
                                          regressor.head[:-2]).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors, device)
        print(mae)
        return

    # start training
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, mdd, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors, device)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))

    logger.close()
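
This and several of the following examples pair SGD with the same LambdaLR schedule, lr(x) = lr * (1 + lr_gamma * x) ** (-lr_decay), stepped once per iteration. The lambda multiplies by the absolute learning rate because get_parameters() is assumed to return parameter groups whose lr values are relative multipliers (e.g. 0.1 for the pretrained backbone, 1.0 for newly added layers). A small self-contained sketch with illustrative values:

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

lr, lr_gamma, lr_decay = 0.01, 0.001, 0.75  # illustrative values only

model = nn.Linear(16, 2)
# one group with a relative base lr of 1.0; a real example would also include
# a backbone group with a relative base lr of 0.1
optimizer = SGD([{'params': model.parameters(), 'lr': 1.0}], lr=1.0, momentum=0.9)
scheduler = LambdaLR(
    optimizer, lambda x: lr * (1. + lr_gamma * float(x)) ** (-lr_decay))

for step in range(3):
    optimizer.zero_grad()
    model(torch.randn(4, 16)).sum().backward()
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())  # decays from 0.01 towards 0
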
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(
        args.train_resizing,
        random_horizontal_flip=not args.no_hflip,
        random_color_jitter=False,
        resize_size=args.resize_size,
        norm_mean=args.norm_mean,
        norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing,
                                            resize_size=args.resize_size,
                                            norm_mean=args.norm_mean,
                                            norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone,
                                 num_classes,
                                 bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer,
                                 finetune=not args.scratch).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim,
                                        hidden_size=1024).to(device)

    # define loss function
    domain_adv = DomainAdversarialLoss().to(device)
    gl = WarmStartGradientLayer(alpha=1.,
                                lo=0.,
                                hi=1.,
                                max_iters=1000,
                                auto_step=True)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay,
                    nesterov=True)
    optimizer_d = SGD(domain_discri.get_parameters(),
                      args.lr_d,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay,
                      nesterov=True)
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x: args.lr_d *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone,
                                          classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr classifier:", lr_scheduler.get_lr())
        print("lr discriminator:", lr_scheduler_d.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_discri,
              domain_adv, gl, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
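
The WarmStartGradientLayer above gradually ramps the strength of the adversarial signal from lo to hi over max_iters steps. Its implementation is not shown here; the coefficient typically follows the sigmoid warm-up schedule popularized by DANN, sketched below as an assumption rather than the library's exact code:

import math

def warm_start_coeff(step, alpha=1., lo=0., hi=1., max_iters=1000):
    # Sigmoid ramp from lo (at step 0) towards hi as step / max_iters grows:
    # 2*(hi - lo) / (1 + exp(-alpha * p)) - (hi - lo) + lo, with p = step / max_iters.
    p = step / max_iters
    return 2. * (hi - lo) / (1. + math.exp(-alpha * p)) - (hi - lo) + lo
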
# Example 6
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone,
                            upsampling,
                            256,
                            num_keypoints,
                            num_head_layers=args.num_head_layers,
                            finetune=True).to(device)
    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = PseudoLabelGenerator(num_keypoints,
                                                  args.heatmap_size,
                                                  args.heatmap_size)
    regression_disparity = RegressionDisparity(pseudo_label_generator,
                                               JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    optimizer_f = SGD([
        {
            'params': backbone.parameters(),
            'lr': 0.1
        },
        {
            'params': upsampling.parameters(),
            'lr': 0.1
        },
    ],
                      lr=0.1,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h = SGD(model.head.parameters(),
                      lr=1.,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(),
                          lr=1.,
                          momentum=args.momentum,
                          weight_decay=args.wd,
                          nesterov=True)
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x))**(
        -args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256,
                                          num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr),
                            momentum=args.momentum,
                            weight_decay=args.wd,
                            nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                lr_scheduler.step()
                print(lr_scheduler.get_lr())

                pretrain(train_source_iter, pretrained_model, criterion,
                         optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model,
                                          criterion, None, args)

                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save({'model': pretrained_model.state_dict()},
                               args.pretrain)
                print("Source: {} best: {}".format(source_val_acc['all'],
                                                   best_acc))

        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain,
                                     map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from the pretrained dict that don't appear in the model dict
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saved image
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'],
                                                       target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(),
              lr_scheduler_h_adv.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
              lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch,
              visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
# Example 7
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    if args.num_channels == 3:
        mode = 'RGB'
        mean = std = [0.5, 0.5, 0.5]
    else:
        mode = 'L'
        mean = std = [0.5]
    normalize = T.Normalize(mean=mean, std=std)

    train_transform = T.Compose([
        ResizeImage(args.image_size),
        # T.RandomRotation(10), # TODO need results
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose(
        [ResizeImage(args.image_size),
         T.ToTensor(), normalize])

    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          mode=mode,
                                          download=True,
                                          transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          mode=mode,
                                          download=True,
                                          transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = target_dataset(root=args.target_root,
                                 mode=mode,
                                 split='test',
                                 download=True,
                                 transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    arch = models.__dict__[args.arch]()
    bottleneck = nn.Sequential(
        nn.Flatten(), nn.Linear(arch.bottleneck_dim, arch.bottleneck_dim),
        nn.BatchNorm1d(arch.bottleneck_dim), nn.ReLU(), nn.Dropout(0.5))
    head = arch.head()
    adv_head = arch.head()
    classifier = GeneralModule(arch.backbone(),
                               arch.num_classes,
                               bottleneck,
                               head,
                               adv_head,
                               finetune=False).to(device)
    mdd = MarginDisparityDiscrepancy(args.margin).to(device)

    # define optimizer and lr scheduler
    optimizer = Adam(classifier.get_parameters(),
                     args.lr,
                     betas=args.betas,
                     weight_decay=args.wd)
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = torch.nn.Sequential(
            classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device, 10)
        target_feature = collect_feature(val_loader, feature_extractor, device,
                                         10)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(val_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mdd, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    logger.close()
def train(model,
          optimizer,
          train_data,
          val_data,
          params,
          metric=accuracy_score,
          criterion=nn.CrossEntropyLoss(),
          variable_created_by_model=True):

    mean_train_loss = []
    mean_val_loss = []
    mean_train_metric = []
    mean_val_metric = []

    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda epoch: 0.5**(epoch // params["lr_ep_step"]))

    for epoch in range(params["epochs"]):
        start_time = time.time()

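        # Note: since PyTorch 1.1 the expected order is optimizer.step() first,
        # then scheduler.step(); stepping the scheduler at the start of each
        # epoch shifts the decay schedule forward by one epoch.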
        scheduler.step()
        print("current lr = {}".format(scheduler.get_lr()[0]))

        train_loss, train_preds, train_targets = train_one_epoch(
            model, optimizer, train_data, params, criterion,
            variable_created_by_model)
        val_loss, val_preds, val_targets = validate(model, val_data, params,
                                                    criterion,
                                                    variable_created_by_model)

        # print the results for this epoch:
        mean_train_loss.append(np.mean(train_loss))
        mean_val_loss.append(np.mean(val_loss))
        mean_train_metric.append(metric(train_targets, train_preds))
        mean_val_metric.append(metric(val_targets, val_preds))

        clear_output(True)
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.plot(mean_train_loss)
        plt.plot(mean_val_loss)
        plt.subplot(122)
        plt.plot(mean_train_metric)
        plt.plot(mean_val_metric)
        plt.gca().set_ylim([0, 1])
        plt.show()
        print("Epoch {} of {} took {:.3f}s".format(epoch + 1, params["epochs"],
                                                   time.time() - start_time))
        print("  training loss (in-iteration): \t{:.6f}".format(
            mean_train_loss[-1]))
        print("  validation loss: \t\t\t{:.6f}".format(mean_val_loss[-1]))
        print("  training metric: \t\t\t{:.2f}".format(mean_train_metric[-1]))
        print("  validation metric: \t\t\t{:.2f}".format(mean_val_metric[-1]))

#         if mean_train_loss[-1] < epsilon:
#             break

    return mean_train_loss, mean_val_loss, mean_train_metric, mean_val_metric


# ? def cross_val_trains
# Example 9
def train(model, tokenizer, train_data, valid_data, args):
    model.train()

    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn_bert(
                                      x, tokenizer, args.max_seq_length))

    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(valid_dataset,
                                  sampler=SequentialSampler(valid_dataset),
                                  batch_size=args.eval_batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn_bert(
                                      x, tokenizer, args.max_seq_length))

    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]

    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
    #                              betas=eval(args.adam_betas), eps=args.eps,
    #                              weight_decay=args.weight_decay)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (
        x / args.num_warmup_steps)**-0.5
    scheduler = LambdaLR(optimizer, lr_lambda)

    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        for batch in train_dataloader:

            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            noise_input_ids, clean_input_ids, noise_mask, clean_mask = batch
            #print("noise shape: {}, clean shape: {}".format(noise_input_ids.shape, clean_input_ids.shape))

            outputs = model(noise_input_ids,
                            labels=clean_input_ids,
                            attention_mask=noise_mask)
            loss = outputs[0]
            predict_score = outputs[1]

            bsz = clean_input_ids.size(0)
            items = [loss.data.item(), bsz, clean_mask.sum().item()]
            #print("items: ", items)
            meter.add(*items)

            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()

            if step % args.log_interval == 0:
                lr = scheduler.get_lr()[0]
                loss_sent, loss_token = meter.average()

                logger.info(
                    f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            train__lr=lr,
                            train__loss_sent=loss_sent,
                            train__token_ppl=math.exp(loss_token))
                meter.init()

            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate_kcBert(
                    model, valid_dataloader, args)
                prediction = correct_kcBert(model,
                                            tokenizer,
                                            valid_noisy,
                                            args,
                                            length_limit=0.1)
                val_em = em(prediction, valid_clean)
                cnt = 0
                for noisy, pred, clean in zip(valid_noisy, prediction,
                                              valid_clean):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                    # print only 10 of them
                    cnt += 1
                    if cnt == 20:
                        break
                # print("len of prediction: {}, len of valid_clean: {}", len(prediction), len(valid_clean))
                val_gleu = gleu(prediction, valid_clean)

                logger.info('-' * 89)
                logger.info(
                    f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}'
                )
                logger.info('-' * 89)
                nsml.report(step=step,
                            scope=locals(),
                            summary=True,
                            valid__loss_sent=val_loss,
                            valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em,
                            valid__gleu=val_gleu)

                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                meter.start += time.time() - start_eval

            if step >= args.max_steps:
                break
        if step >= args.max_steps:
            break
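
The warmup lambda above returns x / args.num_warmup_steps during warmup and (x / args.num_warmup_steps) ** -0.5 afterwards, so the factor at step 0 is zero and the very first update is taken with a zero learning rate. A named variant with a guard against that, offered as a sketch rather than the original helper:

from functools import partial

def inverse_sqrt_warmup(step, num_warmup_steps):
    # Linear warmup to the base lr, then inverse-square-root decay.
    step = max(step, 1)  # avoid a zero factor on the very first update
    if step <= num_warmup_steps:
        return step / num_warmup_steps
    return (step / num_warmup_steps) ** -0.5

# scheduler = LambdaLR(optimizer, partial(inverse_sqrt_warmup,
#                                         num_warmup_steps=args.num_warmup_steps))
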
# Example 10
class PPOAgent(BaseAgent):
    actor: nn.Module
    critic: nn.Module
    same_body: bool = False

    def __post_init__(self):
        move_to([self.actor, self.critic], device=cfg.alg.device)
        if cfg.alg.vf_loss_type == 'mse':
            self.val_loss_criterion = nn.MSELoss().to(cfg.alg.device)
        elif cfg.alg.vf_loss_type == 'smoothl1':
            self.val_loss_criterion = nn.SmoothL1Loss().to(cfg.alg.device)
        else:
            raise ValueError(
                f'Unknown value loss type: {cfg.alg.vf_loss_type}!')
        all_params = list(self.actor.parameters()) + list(
            self.critic.parameters())
        # Keep unique parameters only (shared layers may appear in both lists).
        # dict preserves insertion order on Python >= 3.7; use OrderedDict on
        # earlier versions.
        self.all_params = dict.fromkeys(all_params).keys()
        if (cfg.alg.linear_decay_lr or cfg.alg.linear_decay_clip_range) and \
                cfg.alg.max_steps > cfg.alg.max_decay_steps:
            logger.warning(
                'max_steps should not be greater than max_decay_steps.')
            cfg.alg.max_decay_steps = int(cfg.alg.max_steps * 1.5)
            logger.warning(
                f'Resetting max_decay_steps to {cfg.alg.max_decay_steps}!')
        total_epochs = int(
            np.ceil(cfg.alg.max_decay_steps /
                    (cfg.alg.num_envs * cfg.alg.episode_steps)))
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = cfg.alg.clip_range / float(
                total_epochs)

        p_lr_lambda = partial(linear_decay_percent, total_epochs=total_epochs)
        optim_args = dict(lr=cfg.alg.policy_lr,
                          weight_decay=cfg.alg.weight_decay)
        if not cfg.alg.sgd:
            optim_args['amsgrad'] = cfg.alg.use_amsgrad
            optim_func = optim.Adam
        else:
            optim_args['nesterov'] = cfg.alg.momentum > 0
            optim_args['momentum'] = cfg.alg.momentum
            optim_func = optim.SGD
        if self.same_body:
            optim_args['params'] = self.all_params
        else:
            optim_args['params'] = [{
                'params': self.actor.parameters(),
                'lr': cfg.alg.policy_lr
            }, {
                'params': self.critic.parameters(),
                'lr': cfg.alg.value_lr
            }]

        self.optimizer = optim_func(**optim_args)

        if self.same_body:
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda])
        else:
            v_lr_lambda = partial(linear_decay_percent,
                                  total_epochs=total_epochs)
            self.lr_scheduler = LambdaLR(optimizer=self.optimizer,
                                         lr_lambda=[p_lr_lambda, v_lr_lambda])

    @torch.no_grad()
    def get_action(self, ob, sample=True, *args, **kwargs):
        self.eval_mode()
        if isinstance(ob, dict):
            t_ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            t_ob = torch_float(ob, device=cfg.alg.device)

        act_dist, val = self.get_act_val(t_ob)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        return torch_to_np(action), action_info

    def get_act_val(self, ob, *args, **kwargs):
        if isinstance(ob, dict):
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        act_dist, body_out = self.actor(ob)
        if self.same_body:
            val, body_out = self.critic(body_x=body_out)
        else:
            val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return act_dist, val

    @torch.no_grad()
    def get_val(self, ob, *args, **kwargs):
        self.eval_mode()

        if isinstance(ob, dict):
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return val

    def optimize(self, data, *args, **kwargs):
        processed_data = self.optim_preprocess(data)
        processed_data['entropy'] = torch.mean(processed_data['entropy'])
        loss_res = self.cal_loss(**processed_data)
        loss, pg_loss, vf_loss, ratio = loss_res
        self.optimizer.zero_grad()
        loss.backward()

        grad_norm = clip_grad(self.all_params, cfg.alg.max_grad_norm)
        self.optimizer.step()
        with torch.no_grad():
            approx_kl = 0.5 * torch.mean(
                torch.pow(
                    processed_data['old_log_prob'] -
                    processed_data['log_prob'], 2))
            clip_frac = np.mean(
                np.abs(torch_to_np(ratio) - 1.0) > cfg.alg.clip_range)
        optim_info = dict(pg_loss=pg_loss.item(),
                          vf_loss=vf_loss.item(),
                          total_loss=loss.item(),
                          entropy=processed_data['entropy'].item(),
                          approx_kl=approx_kl.item(),
                          clip_frac=clip_frac)
        optim_info['grad_norm'] = grad_norm
        return optim_info

    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']

        act_dist, val = self.get_act_val({"ob": ob, "state": state})
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        if not all(x.ndim == 1 for x in (val, entropy, log_prob)):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data

    def cal_loss(self, val, old_val, ret, log_prob, old_log_prob, adv,
                 entropy):
        vf_loss = self.cal_val_loss(val=val, old_val=old_val, ret=ret)
        ratio = torch.exp(log_prob - old_log_prob)
        surr1 = adv * ratio
        surr2 = adv * torch.clamp(ratio, 1 - cfg.alg.clip_range,
                                  1 + cfg.alg.clip_range)
        pg_loss = -torch.mean(torch.min(surr1, surr2))

        loss = pg_loss - entropy * cfg.alg.ent_coef + \
               vf_loss * cfg.alg.vf_coef
        return loss, pg_loss, vf_loss, ratio

    def cal_val_loss(self, val, old_val, ret):
        if cfg.alg.clip_vf_loss:
            clipped_val = old_val + torch.clamp(
                val - old_val, -cfg.alg.clip_range, cfg.alg.clip_range)
            vf_loss1 = torch.pow(val - ret, 2)
            vf_loss2 = torch.pow(clipped_val - ret, 2)
            vf_loss = 0.5 * torch.mean(torch.max(vf_loss1, vf_loss2))
        else:
            # val = torch.squeeze(val)
            vf_loss = 0.5 * self.val_loss_criterion(val, ret)
        return vf_loss

    def train_mode(self):
        self.actor.train()
        self.critic.train()

    def eval_mode(self):
        self.actor.eval()
        self.critic.eval()

    def decay_lr(self):
        self.lr_scheduler.step()

    def get_lr(self):
        cur_lr = self.lr_scheduler.get_lr()
        lrs = {'policy_lr': cur_lr[0]}
        if len(cur_lr) > 1:
            lrs['value_lr'] = cur_lr[1]
        return lrs

    def decay_clip_range(self):
        cfg.alg.clip_range -= self.clip_range_decay_rate

    def save_model(self, is_best=False, step=None):
        self.save_env(cfg.alg.model_dir)
        data_to_save = {
            'step': step,
            'actor_state_dict': self.actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
            'lr_scheduler_state_dict': self.lr_scheduler.state_dict()
        }

        if cfg.alg.linear_decay_clip_range:
            data_to_save['clip_range'] = cfg.alg.clip_range
            data_to_save['clip_range_decay_rate'] = self.clip_range_decay_rate
        save_model(data_to_save, cfg.alg, is_best=is_best, step=step)

    def load_model(self, step=None, pretrain_model=None):
        self.load_env(cfg.alg.model_dir)
        ckpt_data = load_ckpt_data(cfg.alg,
                                   step=step,
                                   pretrain_model=pretrain_model)
        load_state_dict(self.actor, ckpt_data['actor_state_dict'])
        load_state_dict(self.critic, ckpt_data['critic_state_dict'])
        if pretrain_model is not None:
            return
        self.optimizer.load_state_dict(ckpt_data['optim_state_dict'])
        self.lr_scheduler.load_state_dict(ckpt_data['lr_scheduler_state_dict'])
        if cfg.alg.linear_decay_clip_range:
            self.clip_range_decay_rate = ckpt_data['clip_range_decay_rate']
            cfg.alg.clip_range = ckpt_data['clip_range']
        return ckpt_data['step']

    def print_param_grad_status(self):
        logger.info('Requires Grad?')
        logger.info('================== Actor ================== ')
        for name, param in self.actor.named_parameters():
            print(f'{name}: {param.requires_grad}')
        logger.info('================== Critic ================== ')
        for name, param in self.critic.named_parameters():
            print(f'{name}: {param.requires_grad}')
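# A minimal sketch of the two-group schedule built in PPOAgent.__post_init__:
# one lr_lambda per parameter group (policy and value).  `linear_decay_percent`
# is assumed here to return 1 - epoch / total_epochs; treat that as an
# assumption, not the library's definition.
from functools import partial

from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR


def linear_decay_percent(epoch, total_epochs):
    return 1.0 - epoch / float(total_epochs)


actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)
optimizer = optim.Adam([
    {'params': actor.parameters(), 'lr': 3e-4},
    {'params': critic.parameters(), 'lr': 1e-3},
])
p_lr_lambda = partial(linear_decay_percent, total_epochs=100)
v_lr_lambda = partial(linear_decay_percent, total_epochs=100)
# one lambda per param group; LambdaLR multiplies each group's base lr by its factor
lr_scheduler = LambdaLR(optimizer, lr_lambda=[p_lr_lambda, v_lr_lambda])
lr_scheduler.step()  # both lrs are now at 99% of their base values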
Example #11
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        ResizeImage(256),
        T.RandomCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5),
        T.RandomGrayscale(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [ResizeImage(256),
         T.CenterCrop(224),
         T.ToTensor(), normalize])

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target,
                          train_transform, val_transform, MultipleApply([train_transform, val_transform]))
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone,
                                 num_classes,
                                 bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer,
                                 finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler
    optimizer = Adam(classifier.get_parameters(), args.lr)
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone,
                                          classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader,
                                         feature_extractor, device)
        target_feature = collect_feature(train_target_loader,
                                         feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    if args.pretrain is None:
        # first pretrain the classifier with source data
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone,
                                         num_classes,
                                         bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer,
                                         finetune=not args.scratch).to(device)
        pretrain_optimizer = Adam(pretrain_model.get_parameters(),
                                  args.pretrain_lr)
        pretrain_lr_scheduler = LambdaLR(
            pretrain_optimizer, lambda x: args.pretrain_lr *
            (1. + args.lr_gamma * float(x))**(-args.lr_decay))

        # start pretraining
        for epoch in range(args.pretrain_epochs):
            # pretrain for one epoch
            utils.pretrain(train_source_iter, pretrain_model,
                           pretrain_optimizer, pretrain_lr_scheduler, epoch,
                           args, device)
            # validate to show pretrain process
            utils.validate(val_loader, pretrain_model, args, device)

        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    checkpoint = torch.load(args.pretrain, map_location='cpu')
    classifier.load_state_dict(checkpoint)
    teacher = EmaTeacher(classifier, alpha=args.alpha)
    consistent_loss = L2ConsistencyLoss().to(device)
    class_balance_loss = ClassBalanceLoss(num_classes).to(device)

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, teacher,
              consistent_loss, class_balance_loss, optimizer, lr_scheduler,
              epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
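# A standalone sketch of the inverse power decay used above:
# factor(x) = lr * (1 + lr_gamma * x) ** (-lr_decay).  LambdaLR multiplies each
# parameter group's base lr by this factor, so the scheme yields absolute
# learning rates only if classifier.get_parameters() assigns relative base lrs
# (e.g. 1.0 for new layers, 0.1 for the backbone) -- an assumption about that
# helper, which is not shown in the snippet above.
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

lr, lr_gamma, lr_decay = 0.001, 0.001, 0.75
param = torch.nn.Parameter(torch.zeros(1))
optimizer = Adam([{'params': [param], 'lr': 1.0}])  # relative base lr
scheduler = LambdaLR(
    optimizer, lambda x: lr * (1. + lr_gamma * float(x))**(-lr_decay))
for step in range(3):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # absolute lr, decaying slowly from 0.001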
Example #12
    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """
        logger.info(f"config: {self.config}")
        random.seed(self.config.SEED)
        np.random.seed(self.config.SEED)
        torch.manual_seed(self.config.SEED)

        # add_signal_handlers()

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME),
                                   workers_ignore_signals=True)

        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        if ppo_cfg.use_external_memory:
            memory_dim = self.actor_critic.net.memory_dim
        else:
            memory_dim = None

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            self.envs.observation_spaces[0],
            self.envs.action_spaces[0],
            ppo_cfg.hidden_size,
            ppo_cfg.use_external_memory,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
            memory_dim,
        )
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations)
        if self.config.RL.PPO.use_belief_predictor:
            self.belief_predictor.update(batch, None)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1),
            reward=torch.zeros(self.envs.num_envs, 1),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        interrupted_state = load_interrupted_state(
            model_dir=self.config.MODEL_DIR)
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optimizer_state"])
            lr_scheduler.load_state_dict(
                interrupted_state["lr_scheduler_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set():
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        save_interrupted_state(dict(
                            state_dict=self.agent.state_dict(),
                            optimizer_state=self.agent.optimizer.state_dict(),
                            lr_scheduler_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        ),
                                               model_dir=self.config.MODEL_DIR)
                        requeue_job()
                    return

                for step in range(ppo_cfg.num_steps):
                    delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                        rollouts, current_episode_reward,
                        running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                    ppo_cfg, rollouts)
                pth_time += delta_pth_time

                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in window_episode_stats.items()
                }
                deltas["count"] = max(deltas["count"], 1.0)

                writer.add_scalar("Metrics/reward",
                                  deltas["reward"] / deltas["count"],
                                  count_steps)

                # Check to see if there are any metrics
                # that haven't been logged yet
                metrics = {
                    k: v / deltas["count"]
                    for k, v in deltas.items() if k not in {"reward", "count"}
                }
                if len(metrics) > 0:
                    # writer.add_scalars("metrics", metrics, count_steps)
                    for metric, value in metrics.items():
                        writer.add_scalar(f"Metrics/{metric}", value,
                                          count_steps)

                writer.add_scalar("Policy/value_loss", value_loss, count_steps)
                writer.add_scalar("Policy/policy_loss", action_loss,
                                  count_steps)
                writer.add_scalar("Policy/entropy_loss", dist_entropy,
                                  count_steps)
                writer.add_scalar('Policy/learning_rate',
                                  lr_scheduler.get_lr()[0], count_steps)

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update, count_steps / (time.time() - t_start)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    logger.info("Average window size: {}  {}".format(
                        len(window_episode_stats["count"]),
                        "  ".join("{}: {:.3f}".format(k, v / deltas["count"])
                                  for k, v in deltas.items() if k != "count"),
                    ))

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                    count_checkpoints += 1

            self.envs.close()
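# `linear_decay` above is assumed to be the usual 1 - update / total_updates
# ramp (so the learning rate reaches ~0 at NUM_UPDATES); a minimal sketch of
# that assumption wired into LambdaLR:
import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_decay(update, total_num_updates):
    # multiplicative factor: 1.0 at update 0, approaching 0.0 at the last update
    return 1.0 - (update / float(total_num_updates))


NUM_UPDATES = 10000
optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=2.5e-4)
lr_scheduler = LambdaLR(optimizer,
                        lr_lambda=lambda x: linear_decay(x, NUM_UPDATES))
lr_scheduler.step()  # called once per update when use_linear_lr_decay is set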
Example #13
def train(data_path: str,
          data_directory: str,
          generate_vocabularies: bool,
          input_vocab_path: str,
          target_vocab_path: str,
          embedding_dimension: int,
          num_encoder_layers: int,
          encoder_dropout_p: float,
          encoder_bidirectional: bool,
          training_batch_size: int,
          test_batch_size: int,
          max_decoding_steps: int,
          num_decoder_layers: int,
          decoder_dropout_p: float,
          cnn_kernel_size: int,
          cnn_dropout_p: float,
          cnn_hidden_num_channels: int,
          simple_situation_representation: bool,
          decoder_hidden_size: int,
          encoder_hidden_size: int,
          learning_rate: float,
          adam_beta_1: float,
          adam_beta_2: float,
          lr_decay: float,
          lr_decay_steps: int,
          resume_from_file: str,
          max_training_iterations: int,
          output_directory: str,
          print_every: int,
          evaluate_every: int,
          conditional_attention: bool,
          auxiliary_task: bool,
          weight_target_loss: float,
          attention_type: str,
          max_training_examples=None,
          seed=42,
          **kwargs):
    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation)
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    logger.info("Loading Test set...")
    test_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="test",  # TODO: use dev set here
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=False)
    test_set.read_dataset(
        max_examples=None,
        simple_situation_representation=simple_situation_representation)

    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    test_set.shuffle_data()
    logger.info("Done Loading Test set.")

    model = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                  target_vocabulary_size=training_set.target_vocabulary_size,
                  num_cnn_channels=training_set.image_channels,
                  input_padding_idx=training_set.input_vocabulary.pad_idx,
                  target_pad_idx=training_set.target_vocabulary.pad_idx,
                  target_eos_idx=training_set.target_vocabulary.eos_idx,
                  **cfg)
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions,
             target_positions) in training_set.get_data_iterator(
                 batch_size=training_batch_size):
            is_best = False
            model.train()

            # Forward pass.
            target_scores, target_position_scores = model(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                target_lengths=target_lengths)
            loss = model.get_loss(target_scores, target_batch)
            if auxiliary_task:
                target_loss = model.get_auxiliary_loss(target_position_scores,
                                                       target_positions)
            else:
                target_loss = 0
            loss += weight_target_loss * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, target_batch)
                if auxiliary_task:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, target_positions)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))

            # Evaluate on test set.
            if training_iteration % evaluate_every == 0:
                with torch.no_grad():
                    model.eval()
                    logger.info("Evaluating..")
                    accuracy, exact_match, target_accuracy = evaluate(
                        test_set.get_data_iterator(batch_size=1),
                        model=model,
                        max_decoding_steps=max_decoding_steps,
                        pad_idx=test_set.target_vocabulary.pad_idx,
                        sos_idx=test_set.target_vocabulary.sos_idx,
                        eos_idx=test_set.target_vocabulary.eos_idx,
                        max_examples_to_evaluate=kwargs["max_testing_examples"]
                    )
                    logger.info(
                        "  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
                        " Target Accuracy: %5.2f" %
                        (accuracy, exact_match, target_accuracy))
                    if exact_match > best_exact_match:
                        is_best = True
                        best_accuracy = accuracy
                        best_exact_match = exact_match
                        model.update_state(accuracy=accuracy,
                                           exact_match=exact_match,
                                           is_best=is_best)
                    file_name = "checkpoint.pth.tar".format(
                        str(training_iteration))
                    if is_best:
                        model.save_checkpoint(
                            file_name=file_name,
                            is_best=is_best,
                            optimizer_state_dict=optimizer.state_dict())

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
    logger.info("Finished training.")
def train(args, data_loader):
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    stream = torch.cuda.Stream(device)

    def to_device(args):
        x1, x2 = args
        with torch.cuda.stream(stream):
            x1 = x1.to(device, non_blocking=True)
            x2 = x2.to(device, non_blocking=True)
        return x1, x2

    train_data = DataQueue(data_loader, max_queue_size=30, nb_worker=10)
    train_loader = Prefetcher(train_data,
                              postprocess=to_device,
                              buffer_size=1,
                              stream=stream)

    model_cfg = ConfigParser()
    model_cfg.read(args.conf)
    max_grad_value = model_cfg.getfloat(args.model_type, "max_grad_value")
    max_grad_norm = model_cfg.getfloat(args.model_type, "max_grad_norm")

    beta = args.beta
    model_q = build_model(args.conf,
                          model_type=args.model_type,
                          write_back=True)
    model_k = build_model(args.conf,
                          model_type=args.model_type,
                          write_back=True)
    momentum_update(model_q, model_k, 1 - beta)

    embedding_size = model_cfg.getint(args.model_type, "embedding_size")
    memory = MemoryMoCo(embedding_size, args.mem_queue_size)

    if torch.cuda.device_count() > 1:
        model_q = nn.DataParallel(model_q, dim=0)
        model_k = nn.DataParallel(model_k, dim=0)
    model_q.to(device)
    model_k.to(device)
    memory.to(device)

    num_epochs = args.num_epochs
    warmup_lr = args.warmup_lr
    initial_lr = args.initial_lr
    final_lr = args.final_lr

    lr_anneal = (final_lr / initial_lr)**(1. / num_epochs)
    lr_decline = (initial_lr - final_lr) / num_epochs

    # Set up learning rate scheduler
    def get_learning_rate(epoch):
        """Compute learning rate of given epoch.

        Users can design different strategy to alter learning rate,
        Please make sure global variables like final_lr、
        lr_decline、initial_lr are assigned before.
        """
        if args.linear_decay:
            this_lr = max(final_lr, initial_lr - lr_decline * epoch)
        else:
            this_lr = max(final_lr, initial_lr * lr_anneal**epoch)

        if epoch == 0 and warmup_lr > 0:  # use warmup lr instead
            this_lr = warmup_lr
        return this_lr

    momentum = 0.9
    nesterov = False
    weight_decay = 1e-5

    # Optimizer
    optimizer = torch.optim.SGD(model_q.parameters(),
                                lr=initial_lr,
                                momentum=momentum,
                                nesterov=nesterov,
                                weight_decay=weight_decay)

    lr_lambda = lambda epoch: get_learning_rate(epoch) / initial_lr
    lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

    model_dir = args.model_dir
    log_dir = os.path.join(model_dir, f'log/dist{torch.cuda.device_count()}')
    writer = SummaryWriter(log_dir)

    # Train the Model
    checkpoint_period = args.checkpoint_period
    for epoch in range(args.start_epoch, num_epochs):

        if args.exit_epoch > 0 and args.exit_epoch == epoch:
            break

        ckpt_filename = os.path.join(model_dir,
                                     f'checkpoint_e{epoch-1:03d}.pkl')
        if epoch == args.start_epoch and os.path.isfile(ckpt_filename):
            ckpt = load_checkpoint(model_q, ckpt_filename, map_location='cpu')
            if args.start_epoch > 0:
                if isinstance(ckpt, dict) and 'optimizer' in ckpt:
                    optimizer.load_state_dict(ckpt['optimizer'])
                    print(f'load optimizer states from {ckpt_filename}')
                if isinstance(
                        ckpt, dict
                ) and 'meta' in ckpt and 'lr_state' in ckpt['meta']:
                    lr_state = ckpt['meta'].get('lr_state')
                    lr_scheduler.load_state_dict(lr_state)
                    print(f'load scheduler states from {ckpt_filename}')
            print(f'load model from {ckpt_filename}')

        if epoch == args.start_epoch:
            nb_samples = epoch * args.frames_per_epoch

        writer.add_scalar('train/lr', lr_scheduler.get_lr()[0], nb_samples)

        total_loss = 0
        acc_samples = 0
        total_processed = 0
        steps = 0
        target_samples = checkpoint_period
        while total_processed < args.frames_per_epoch:
            inputs, dis_inputs = train_loader.get()
            inputs = inputs.to(device)
            dis_inputs = dis_inputs.to(device)

            b, t = inputs.size()[:2]

            # Forward + Backward + Optimize
            optimizer.zero_grad()  # zero the gradient buffer
            with torch.no_grad():
                # Shuffle BN
                shf_ids, rev_ids = get_shuffle_ids(b, device)
                dis_inputs = dis_inputs[shf_ids]
                key = model_k(dis_inputs)[rev_ids].detach()
            query = model_q(inputs)

            loss = memory(query, key)
            loss.backward()

            if max_grad_value > 0:
                cur_max_value = max_grad_value
                clip_grad_value_(model_q.parameters(),
                                 clip_value=cur_max_value)
            if max_grad_norm > 0:
                cur_max_norm = max_grad_norm
                norm = clip_grad_norm_(model_q.parameters(),
                                       max_norm=cur_max_norm)
                if norm > cur_max_norm:
                    print(
                        "grad norm {0:.2f} exceeds {1:.2f}, clip to {1:.2f}.".
                        format(norm, cur_max_norm))

            optimizer.step()
            momentum_update(model_q, model_k, beta)
            memory.update(key)

            loss_val = loss.item()
            del loss, key, query, inputs, dis_inputs

            total_processed += b * t
            steps += 1
            nb_samples += b * t
            writer.add_scalar('train/loss', loss_val, nb_samples)
            total_loss += loss_val * b
            acc_samples += b
            if steps % 10 == 0:
                print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f' %
                      (epoch + 1, num_epochs, total_processed,
                       args.frames_per_epoch, total_loss / acc_samples))
                total_loss = 0
                acc_samples = 0

            if checkpoint_period > 0 and total_processed >= target_samples:
                target_samples += checkpoint_period
                ckpt_filename = os.path.join(
                    model_dir,
                    'checkpoint_s{:06d}M.pkl'.format(nb_samples // 1000000))
                meta = {}
                meta['lr_state'] = lr_scheduler.state_dict()
                save_checkpoint(model_q,
                                ckpt_filename,
                                optimizer=optimizer,
                                meta=meta)

        lr_scheduler.step()
        ckpt_filename = os.path.join(
            model_dir, 'checkpoint_s{:06d}M.pkl'.format(nb_samples // 1000000))
        meta = {}
        meta['lr_state'] = lr_scheduler.state_dict()
        save_checkpoint(model_q, ckpt_filename, optimizer=optimizer, meta=meta)

        ckpt_linkname = os.path.join(model_dir,
                                     'checkpoint_e{:03d}.pkl'.format(epoch))
        cmd = "ln -sf ./checkpoint_s{:06d}M.pkl {}"
        cmd = cmd.format(nb_samples // 1000000, ckpt_linkname)
        subprocess.call(cmd, shell=True)
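# Two things the example above relies on, as a minimal standalone sketch:
# (1) an absolute per-epoch learning rate (get_learning_rate) becomes a LambdaLR
#     factor by dividing by the initial lr, and
# (2) the scheduler survives a checkpoint via state_dict()/load_state_dict(),
#     which is what meta['lr_state'] stores.  The schedule constants below are
#     illustrative assumptions.
import torch
from torch.optim.lr_scheduler import LambdaLR

initial_lr, final_lr, warmup_lr, num_epochs = 0.1, 0.001, 0.01, 30
lr_anneal = (final_lr / initial_lr)**(1. / num_epochs)


def get_learning_rate(epoch):
    if epoch == 0 and warmup_lr > 0:
        return warmup_lr  # warmup epoch uses its own rate
    return max(final_lr, initial_lr * lr_anneal**epoch)


def make_optimizer_and_scheduler():
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))],
                                lr=initial_lr)
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda e: get_learning_rate(e) / initial_lr)
    return optimizer, scheduler


optimizer, scheduler = make_optimizer_and_scheduler()
scheduler.step()
state = scheduler.state_dict()  # saved as meta['lr_state'] above

optimizer, scheduler = make_optimizer_and_scheduler()  # fresh run / restart
scheduler.load_state_dict(state)  # resumes with the same epoch counter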
Example #15
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([T.Resize(128), T.ToTensor(), normalize])
    val_transform = T.Compose([T.Resize(128), T.ToTensor(), normalize])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root,
                                   task=args.source,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    train_target_dataset = dataset(root=args.root,
                                   task=args.target,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     drop_last=True)
    val_dataset = dataset(root=args.root,
                          task=args.target,
                          split='test',
                          download=True,
                          transform=val_transform)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    num_factors = train_source_dataset.num_factors
    regressor = Regressor(backbone=backbone,
                          num_factors=num_factors).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(regressor.get_parameters(),
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))

    if args.phase == 'test':
        regressor.load_state_dict(
            torch.load(logger.get_checkpoint_path('best')))
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors)
        print(mae)
        return

    # start training
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        mae = validate(val_loader, regressor, args,
                       train_source_dataset.factors)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))

    logger.close()
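# All of these examples read the current value with scheduler.get_lr().  On
# recent PyTorch (>= 1.4), get_last_lr() is the documented way to query the
# rate most recently applied by step(), and get_lr() warns when called outside
# of step().  A minimal sketch (the optimizer and lambda are placeholders):
import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.95**epoch)
print(scheduler.get_last_lr())  # [0.1]
scheduler.step()
print(scheduler.get_last_lr())  # [0.095]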
Example #16
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: data iterator
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr

        # misc params
        self.no_tqdm = config.no_tqdm
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0
        self.lr_patience = config.lr_patience
        self.train_patience = config.train_patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.model_name += ('_uncertainty_1'
                            if config.uncertainty else '_uncertainty_0')
        self.model_name += ('_intrinsic_1'
                            if config.intrinsic else '_intrinsic_0')

        self.plot_dir = './plots/' + self.model_name + '/'
        if not os.path.exists(self.plot_dir):
            os.makedirs(self.plot_dir)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            if not os.path.exists(tensorboard_dir):
                os.makedirs(tensorboard_dir)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(self.patch_size, self.num_patches,
                                        self.glimpse_scale, self.num_channels,
                                        self.loc_hidden, self.glimpse_hidden,
                                        self.std, self.hidden_size,
                                        self.num_classes, self.config)
        if self.use_gpu:
            self.model.cuda()

        self.dtypeFloat = (torch.cuda.FloatTensor
                           if self.use_gpu else torch.FloatTensor)
        self.dtypeLong = (torch.cuda.LongTensor
                          if self.use_gpu else torch.LongTensor)

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config.init_lr,
        )
        lambda_of_lr = lambda epoch: 0.95**epoch
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambda_of_lr)
        # self.scheduler = StepLR(self.optimizer,step_size=20,gamma=0.1)
        # self.scheduler = ReduceLROnPlateau(
        #     self.optimizer, 'min', patience=self.lr_patience
        # )

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.
        """
        dtype = (torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor)

        h_t = torch.zeros(self.batch_size, self.hidden_size)
        h_t = Variable(h_t).type(dtype)

        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1)
        l_t = Variable(l_t).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print(
            "\n[*] Train on {} samples, validate on {} samples, learn rate {}".
            format(self.num_train, self.num_valid, self.scheduler.get_lr()))

        for epoch in range(self.start_epoch, self.epochs):

            print('\nEpoch: {}/{} . lr: {:.4e} '.format(
                epoch + 1, self.epochs,
                self.scheduler.get_lr()[0]))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            self.scheduler.step()

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                self.counter = 0
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement
            if not is_best:
                self.counter += 1
            if self.counter > self.train_patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train, disable=self.no_tqdm) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.config.use_translate:
                    x = translate_function(x, original_dataset=x)
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()
                x, y = Variable(x), Variable(y)

                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                all_log_probas = []  # the prediction at each glimpse step
                uncertainities = []  # the self-uncertainty at each glimpse step
                # baseline self-uncertainty at each glimpse step; only used for the
                # loss that trains the self-uncertainty (error) network
                uncertainities_baseline = []

                # by default the loop runs for all self.num_glimpses steps
                num_glimpses_taken = [
                    self.num_glimpses - 1 for _ in range(self.batch_size)
                ]

                for t in range(self.num_glimpses):

                    # forward pass through model
                    h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                        x, l_t, h_t, last=True)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)
                    all_log_probas.append(log_probas)
                    uncertainities.append(diff_uncertainty)
                    uncertainities_baseline.append(diff_uncertainty_baseline)

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)
                if self.config.uncertainty:
                    uncertainities = torch.stack(uncertainities).transpose(
                        1, 0)
                    uncertainities_baseline = torch.stack(
                        uncertainities_baseline).transpose(1, 0)
                all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

                # calculate reward
                num_glimpses_taken_indices = torch.LongTensor(
                    num_glimpses_taken).type(self.dtypeLong)
                log_probas = torch.cat([
                    torch.index_select(a, 0, i).unsqueeze(0)
                    for a, i in zip(all_log_probas, num_glimpses_taken_indices)
                ]).squeeze()
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)
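                # R is the terminal reward (1 if the final prediction is
                # correct, else 0), broadcast to every glimpse step so each
                # step's baseline and REINFORCE terms see the same return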

                # compute losses for differentiable modules
                num_glimpses_taken = Variable(
                    torch.LongTensor(num_glimpses_taken),
                    requires_grad=False).type(self.dtypeLong)

                # the mask selects the glimpse steps actually taken
                # (up to and including the last one)
                mask = _sequence_mask(sequence_length=num_glimpses_taken,
                                      max_len=self.num_glimpses)
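                # `_sequence_mask` is defined elsewhere in this repo; a minimal
                # sketch of what it is assumed to compute (a float mask that is
                # 1 for every glimpse index up to and including the last one
                # taken, 0 afterwards) could look like:
                #
                #     def _sequence_mask(sequence_length, max_len):
                #         steps = torch.arange(max_len).to(sequence_length.device)
                #         return (steps.unsqueeze(0)
                #                 <= sequence_length.unsqueeze(1)).float()
                #
                # with the default `num_glimpses_taken` this mask is all ones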
                loss_action = F.nll_loss(log_probas, y, reduction='none')
                loss_action = torch.mean(loss_action)

                loss_baseline = F.mse_loss(baselines, R, reduction='none')
                loss_baseline = torch.mean(loss_baseline * mask)
                # loss_baseline = torch.mean( loss_baseline  )

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask,
                                           dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                if self.config.uncertainty == True:
                    y_real_value = F.one_hot(
                        y, self.num_classes).float().detach()
                    diff_ = Variable(torch.abs(
                        y_real_value.unsqueeze(1).expand(
                            -1, self.num_glimpses, -1).data -
                        torch.exp(all_log_probas).data),
                                     requires_grad=False)
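                    # the regression target for the self-uncertainty head is
                    # the per-class absolute error |one_hot(y) - p_t| at every
                    # glimpse step t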
                    loss_self_uncertaintiy_baseline = F.mse_loss(
                        uncertainities_baseline, diff_,
                        reduction='none').mean()

                    loss += loss_self_uncertaintiy_baseline

                if self.config.intrinsic == True:
                    # the intrinsic sparsity belief
                    reg = self.config.lambda_intrinsic
                    intrinsic_term = torch.sum(-(1.0 / self.num_classes) *
                                               log_probas)
                    loss_intrinsic = reg * intrinsic_term
                    loss += loss_intrinsic
                if self.config.uncertainty == True:
                    # the second reinforce loss: minimizing the uncertainty
                    reg = self.config.lambda_uncertainty
                    loss_self_uncertaintiy_minimizing = reg * torch.sum(
                        uncertainities)
                    loss += loss_self_uncertaintiy_minimizing

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.data, list(x.size())[0])
                accs.update(acc.data, list(x.size())[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                if self.no_tqdm is not True:
                    pbar.set_description(
                        ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                            (toc - tic), loss.data, acc.data)))
                    pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    if self.use_gpu:
                        imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                        locs = [l.cpu().data.numpy() for l in locs]
                    else:
                        imgs = [g.data.numpy().squeeze() for g in imgs]
                        locs = [l.data.numpy() for l in locs]
                    pickle.dump(
                        imgs,
                        open(self.plot_dir + "g_{}.p".format(epoch + 1), "wb"))
                    pickle.dump(
                        locs,
                        open(self.plot_dir + "l_{}.p".format(epoch + 1), "wb"))

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

            return losses.avg, accs.avg

    def validate(self, epoch, M=1):
        """
        Evaluate the model on the validation set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            if self.config.use_translate:
                x = translate_function(x, original_dataset=x)
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate M times
            x = x.repeat(M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            locs = []
            log_pi = []
            baselines = []
            all_log_probas = []
            uncertainities = []
            uncertainities_baseline = []

            # by default it runs all `self.num_glimpses` glimpses; the list
            # stores the index of the last glimpse taken for each sample
            num_glimpses_taken = [
                self.num_glimpses - 1 for _ in range(self.batch_size)
            ]

            for t in range(self.num_glimpses):

                # forward pass through model
                h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                    x, l_t, h_t, last=True)

                # store
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)
                all_log_probas.append(log_probas)
                uncertainities.append(diff_uncertainty)
                uncertainities_baseline.append(diff_uncertainty_baseline)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)
            if self.config.uncertainty == True:
                uncertainities = torch.stack(uncertainities).transpose(1, 0)
                uncertainities_baseline = torch.stack(
                    uncertainities_baseline).transpose(1, 0)
            all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

            # calculate reward
            num_glimpses_taken_indices = torch.LongTensor(
                num_glimpses_taken).type(self.dtypeLong)
            log_probas = torch.cat([
                torch.index_select(a, 0, i).unsqueeze(0)
                for a, i in zip(all_log_probas, num_glimpses_taken_indices)
            ]).squeeze()
            # average the `self.M` times of prediction
            log_probas = log_probas.view(M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)
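            # averaging log-probabilities over the M stochastic passes amounts
            # to taking an (unnormalized) geometric mean of the M predictive
            # distributions; for the argmax below this is a reasonable
            # approximation to proper model averaging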
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(M, self.num_glimpses)

            # compute losses for differentiable modules
            num_glimpses_taken = Variable(torch.LongTensor(num_glimpses_taken),
                                          requires_grad=False).type(
                                              self.dtypeLong)

            mask = _sequence_mask(sequence_length=num_glimpses_taken,
                                  max_len=self.num_glimpses)
            loss_action = F.nll_loss(log_probas, y, reduction='none')
            loss_action = torch.mean(loss_action)

            loss_baseline = F.mse_loss(baselines, R, reduction='none')
            loss_baseline = torch.mean(loss_baseline * mask)

            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi * adjusted_reward * mask, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce

            if self.config.uncertainty == True:
                y_real_value = F.one_hot(y, self.num_classes).float().detach()
                diff_ = Variable(torch.abs(
                    y_real_value.unsqueeze(1).expand(-1, self.num_glimpses,
                                                     -1).data -
                    torch.exp(all_log_probas).data),
                                 requires_grad=False)

                loss_self_uncertaintiy_baseline = F.mse_loss(
                    uncertainities_baseline, diff_, reduction='none').mean()
                loss += loss_self_uncertaintiy_baseline

            if self.config.intrinsic == True:
                # the intrinsic sparsity belief
                reg = self.config.lambda_intrinsic
                loss_intrinsic = reg * torch.sum(
                    -(1.0 / self.num_classes) * log_probas)
                loss += loss_intrinsic
            if self.config.uncertainty == True:
                # the second reinforce loss: minimizing the uncertainty
                reg = self.config.lambda_uncertainty
                loss_self_uncertaintiy_minimizing = reg * torch.sum(
                    uncertainities)
                loss += loss_self_uncertaintiy_minimizing

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store
            losses.update(loss.data, list(x.size())[0])
            accs.update(acc.data, list(x.size())[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value('valid_loss', losses.avg, iteration)
                log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        self.num_test = len(self.test_loader.sampler)

        all_num_glimpses_taken = []
        for i, (x, y) in enumerate(self.test_loader):
            torch.manual_seed(self.config.random_seed)
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate self.M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            locs = []
            log_pi = []
            baselines = []
            all_log_probas = []
            uncertainities = []

            # by default it runs all `self.config.num_glimpses` glimpses; the
            # list stores the index of the last glimpse taken for each sample
            num_glimpses_taken = [
                self.config.num_glimpses - 1 for _ in range(self.batch_size)
            ]

            for t in range(self.config.num_glimpses):

                # forward pass through model
                h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                    x, l_t, h_t, last=True)
                # store
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)
                all_log_probas.append(log_probas)
                uncertainities.append(diff_uncertainty)

                if self.config.dynamic == True:
                    # determine whether each sample's prediction is already
                    # confident enough to stop taking further glimpses
                    probs_data = torch.exp(log_probas).data.tolist()
                    diff_uncertainty_data = diff_uncertainty.data.tolist()
                    for instance_idx, (prediction, uncertainty) in enumerate(
                            zip(probs_data, diff_uncertainty_data)):
                        a_star_idx = max(enumerate(prediction),
                                         key=lambda x: x[1])[0]
                        a_prime_idx = max(
                            [(idx, pred +
                              self.config.exploration_rate * uncertainty[idx])
                             for idx, pred in enumerate(prediction)
                             if idx != a_star_idx],
                            key=lambda x: x[1])[0]
                        a_star_lower_bound = prediction[
                            a_star_idx] - self.config.exploration_rate * uncertainty[
                                a_star_idx]
                        a_prime_upper_bound = prediction[
                            a_prime_idx] + self.config.exploration_rate * uncertainty[
                                a_prime_idx]
                        if a_star_lower_bound >= a_prime_upper_bound:
                            num_glimpses_taken[instance_idx] = t
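                    # illustrative example (hypothetical numbers, with
                    # exploration_rate = 1.0): prediction = [0.7, 0.2, 0.1] and
                    # uncertainty = [0.05, 0.10, 0.10] give a_star = 0 with
                    # lower bound 0.65 and a_prime = 1 with upper bound 0.30,
                    # so the sample stops taking further glimpses at step t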

                    # stop early once every sample in the batch has settled on
                    # a glimpse count below the maximum
                    if all(num < self.config.num_glimpses - 1
                           for num in num_glimpses_taken):
                        break

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)
            if self.config.uncertainty == True or self.config.dynamic == True:
                uncertainities = torch.stack(uncertainities).transpose(1, 0)
            all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

            all_num_glimpses_taken.extend(num_glimpses_taken)

            # calculate reward
            num_glimpses_taken_indices = torch.LongTensor(
                num_glimpses_taken).type(self.dtypeLong)
            log_probas = torch.cat([
                torch.index_select(a, 0, i).unsqueeze(0)
                for a, i in zip(all_log_probas, num_glimpses_taken_indices)
            ]).squeeze()
            # average the `self.M` times of prediction
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            pred = log_probas.data.max(1, keepdim=True)[1]
            correct += pred.eq(y.data.view_as(pred)).cpu().sum()

        perc = (100. * correct) / (self.num_test)
        error = 100 - perc
        print('[*] Test Acc: {}/{} ({:.2f}% - {:.2f}%)'.format(
            correct, self.num_test, perc, error))
        if self.config.dynamic == True:
            print('use dynamic')
            # +1 converts the stored last-glimpse index into a glimpse count
            avg_num_glimpses_taken = sum(all_num_glimpses_taken) / len(
                all_num_glimpses_taken) + 1
            return (avg_num_glimpses_taken,
                    1.0 * correct.tolist() / self.num_test)
        return 1.0 * correct.tolist() / self.num_test
        # return perc.tolist()

    def test_for_all(
        self,
        range_all=100,
    ):
        """
        Test the model on the held-out test data.
        This is used to run the model under different number of glimpses
        """
        correct = [0 for _ in range(range_all)]

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        self.num_test = len(self.test_loader.sampler)

        all_num_glimpses_taken = []
        for i, (x, y) in enumerate(tqdm(self.test_loader)):
            torch.manual_seed(self.config.random_seed)
            if self.use_gpu:
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)

            # duplicate self.M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            locs = []
            log_pi = []
            baselines = []
            all_log_probas = []
            uncertainities = []

            # by default every sample is assigned the last glimpse index of
            # the `range_all` budget
            num_glimpses_taken = [
                range_all - 1 for _ in range(self.batch_size)
            ]

            for t in range(self.config.num_glimpses):

                # forward pass through model
                h_t, l_t, b_t, log_probas, p, diff_uncertainty, diff_uncertainty_baseline = self.model(
                    x, l_t, h_t, last=True)
                # store
                locs.append(l_t[0:9])
                baselines.append(b_t)
                log_pi.append(p)
                all_log_probas.append(log_probas)
                uncertainities.append(diff_uncertainty)

                if self.config.dynamic == True:
                    # determine whether each sample's prediction is already
                    # confident enough to stop taking further glimpses
                    probs_data = torch.exp(log_probas).data.tolist()
                    diff_uncertainty_data = diff_uncertainty.data.tolist()
                    for instance_idx, (prediction, uncertainty) in enumerate(
                            zip(probs_data, diff_uncertainty_data)):
                        a_star_idx = max(enumerate(prediction),
                                         key=lambda x: x[1])[0]
                        a_prime_idx = max(
                            [(idx, pred +
                              self.config.exploration_rate * uncertainty[idx])
                             for idx, pred in enumerate(prediction)
                             if idx != a_star_idx],
                            key=lambda x: x[1])[0]
                        a_star_lower_bound = prediction[
                            a_star_idx] - self.config.exploration_rate * uncertainty[
                                a_star_idx]
                        a_prime_upper_bound = prediction[
                            a_prime_idx] + self.config.exploration_rate * uncertainty[
                                a_prime_idx]
                        if a_star_lower_bound >= a_prime_upper_bound:
                            num_glimpses_taken[instance_idx] = t

                    # stop early once every sample in the batch has settled on
                    # a glimpse count below the maximum
                    if all(num < self.config.num_glimpses - 1
                           for num in num_glimpses_taken):
                        break

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)
            if self.config.uncertainty == True or self.config.dynamic == True:
                uncertainities = torch.stack(uncertainities).transpose(1, 0)
            all_log_probas = torch.stack(all_log_probas).transpose(1, 0)

            all_num_glimpses_taken.extend(num_glimpses_taken)

            # accumulate accuracy for every glimpse budget
            # (indexing assumes `range_all` does not exceed the number of
            # stored glimpse steps)
            for num in range(range_all):
                # take the prediction made at glimpse step `num`
                log_probas = all_log_probas[:, num]
                # average the `self.M` times of prediction
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.data.max(1, keepdim=True)[1]
                correct[num] += pred.eq(y.data.view_as(pred)).cpu().sum()

        return [1.0 * cor.tolist() / self.num_test for cor in correct]

        # return 1.0 * correct.tolist() / self.num_test

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        # print("[*] Saving model to {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        ckpt = torch.load(ckpt_path)

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])

        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
def train(mode='train',
          train_path='train.conllx',
          model='dozat',
          dataset='conllx',
          dev_path='dev.conllx',
          test_path='test.conllx',
          ud=True,
          output_dir='output',
          emb_dim=0,
          char_emb_dim=0,
          char_model=None,
          tagger=None,
          batch_size=5000,
          n_iters=10,
          dropout_p=0.33,
          num_layers=1,
          print_every=1,
          eval_every=100,
          bi=True,
          var_drop=False,
          upos_pred=False,
          lr=0.001,
          adam_beta1=0.9,
          adam_beta2=0.999,
          weight_decay=0.,
          plateau=False,
          resume=False,
          lr_decay=1.0,
          lr_decay_steps=5000,
          clip=5.,
          momentum=0,
          optimizer='adam',
          glove=True,
          seed=42,
          dim=0,
          window_size=0,
          num_filters=0,
          **kwargs):

    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    cfg = locals().copy()

    torch.manual_seed(seed)
    np.random.seed(seed)

    # load data component
    if dataset == "conllx":
        dataset_obj = ConllXDataset
        fields = get_data_fields()
        _upos = None
        ud = False
    elif dataset == "conllu":
        dataset_obj = ConllUDataset
        fields = get_data_fields_conllu()
        _upos = fields['upos'][-1]
        ud = True
    else:
        raise NotImplementedError()

    _form = fields['form'][-1]
    _pos = fields['pos'][-1]
    _chars = fields['chars'][-1]

    train_dataset = dataset_obj(train_path, fields)
    dev_dataset = dataset_obj(dev_path, fields)
    test_dataset = dataset_obj(test_path, fields)

    logger.info("Loaded %d train examples" % len(train_dataset))
    logger.info("Number of train tokens: %d" % train_dataset.n_tokens)
    logger.info("Loaded %d dev examples" % len(dev_dataset))
    logger.info("Number of train tokens: %d" % dev_dataset.n_tokens)
    logger.info("Loaded %d test examples" % len(test_dataset))
    logger.info("Number of train tokens: %d" % test_dataset.n_tokens)

    form_vocab_path = os.path.join(output_dir, 'vocab.form.pth.tar')
    pos_vocab_path = os.path.join(output_dir, 'vocab.pos.pth.tar')
    char_vocab_path = os.path.join(output_dir, 'vocab.char.pth.tar')

    if not resume:
        # build vocabularies
        # words have a min frequency of 2 to be included; others become <unk>
        # words without a GloVe vector are re-initialized by `unk_init` below
        # (zeros by default; a small random init is left as an alternative)

        # Note: this requires the latest torchtext development version from Github.
        # - git clone https://github.com/pytorch/text.git torchtext
        # - cd torchtext
        # - python setup.py build
        # - python setup.py install

        def unk_init(x):
            # return 0.01 * torch.randn(x)
            return torch.zeros(x)

        if glove:
            logger.info("Using Glove vectors")
            glove_vectors = GloVe(name='6B', dim=100)
            _form.build_vocab(train_dataset,
                              min_freq=2,
                              unk_init=unk_init,
                              vectors=glove_vectors)
            n_unks = 0
            unk_set = set()
            # for now, set UNK words manually
            # (torchtext does not seem to support it yet)
            for i, token in enumerate(_form.vocab.itos):
                if token not in glove_vectors.stoi:
                    n_unks += 1
                    unk_set.add(token)
                    _form.vocab.vectors[i] = unk_init(emb_dim)
            # print(n_unks, unk_set)

        else:
            _form.build_vocab(train_dataset, min_freq=2)

        _pos.build_vocab(train_dataset)
        if ud:
            _upos.build_vocab(train_dataset)
        _chars.build_vocab(train_dataset)

        # save vocabularies
        torch.save(_form.vocab, form_vocab_path)
        torch.save(_pos.vocab, pos_vocab_path)
        torch.save(_chars.vocab, char_vocab_path)

    else:
        # load vocabularies
        _form.vocab = torch.load(form_vocab_path)
        _pos.vocab = torch.load(pos_vocab_path)
        _chars.vocab = torch.load(char_vocab_path)

    print("First 10 vocabulary entries, words: ",
          " ".join(_form.vocab.itos[:10]))
    print("First 10 vocabulary entries, pos tags: ",
          " ".join(_pos.vocab.itos[:10]))
    print("First 10 vocabulary entries, chars: ",
          " ".join(_chars.vocab.itos[:10]))
    if upos_pred:
        print("First 10 vocabulary entries, upos tags: ",
              " ".join(_upos.vocab.itos[:10]))

    n_words = len(_form.vocab)
    n_tags = len(_pos.vocab)
    if upos_pred:
        n_utags = len(_upos.vocab)
    else:
        n_utags = 0
    n_chars = len(_chars.vocab)

    def batch_size_fn(new, count, sofar):
        return len(new.form) + 1 + sofar
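    # with this batch_size_fn, torchtext's Iterator measures batch size in
    # tokens rather than sentences: `sofar` is the running token count and
    # each new example adds its length plus one, so batch_size=5000 means
    # roughly 5000 tokens per batch (e.g. sentences of lengths 20, 30 and 10
    # contribute 21 + 31 + 11 = 63 towards that budget)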

    # iterators
    train_iter = Iterator(train_dataset,
                          batch_size,
                          train=True,
                          sort_within_batch=True,
                          batch_size_fn=batch_size_fn,
                          device=device)
    dev_iter = Iterator(dev_dataset,
                        32,
                        train=False,
                        sort_within_batch=True,
                        device=device)
    test_iter = Iterator(test_dataset,
                         32,
                         train=False,
                         sort_within_batch=True,
                         device=device)

    # if n_iters or eval_every are negative, interpret them as a number of
    # epochs and convert them to iterations
    iters_per_epoch = (len(train_dataset) // batch_size) + 1
    if eval_every < 0:
        logger.info("Setting eval_every to %d epoch(s) = %d iters" %
                    (-1 * eval_every, -1 * eval_every * iters_per_epoch))
        eval_every = iters_per_epoch * eval_every

    if n_iters < 0:
        logger.info("Setting n_iters to %d epoch(s) = %d iters" %
                    (-1 * n_iters, -1 * n_iters * iters_per_epoch))
        n_iters = -1 * n_iters * iters_per_epoch

    # load up the model
    if upos_pred:
        upos_vocab = _upos.vocab
    else:
        upos_vocab = None
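    # Note: `**cfg` forwards every argument of train() (captured with
    # locals().copy() near the top of this function) to the Tagger, in
    # addition to the explicitly passed sizes and vocabularies.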
    model = Tagger(n_words=n_words,
                   n_tags=n_tags,
                   n_utags=n_utags,
                   n_chars=n_chars,
                   form_vocab=_form.vocab,
                   char_vocab=_chars.vocab,
                   pos_vocab=_pos.vocab,
                   upos_vocab=upos_vocab,
                   **cfg)

    # set word vectors
    if glove:
        _form.vocab.vectors = _form.vocab.vectors / torch.std(
            _form.vocab.vectors)
        # print(torch.std(_form.vocab.vectors))
        model.encoder.embedding.weight.data.copy_(_form.vocab.vectors)
        model.encoder.embedding.weight.requires_grad = True

    model = model.cuda() if use_cuda else model

    start_iter = 1
    best_iter = 0
    best_pos_acc = -1.
    test_pos_acc = -1.

    # optimizer and learning rate scheduler
    trainable_parameters = [p for p in model.parameters() if p.requires_grad]
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(trainable_parameters,
                                    lr=lr,
                                    momentum=momentum)
    else:
        optimizer = torch.optim.Adam(trainable_parameters,
                                     lr=lr,
                                     betas=(adam_beta1, adam_beta2))

    # learning rate schedulers
    if not plateau:
        scheduler = LambdaLR(
            optimizer, lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))
    else:
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='max',
                                      factor=0.75,
                                      patience=5,
                                      min_lr=1e-4)
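    # scheduling behaviour: with plateau=False the LambdaLR multiplies the
    # base lr by lr_decay ** (t / lr_decay_steps); e.g. lr_decay=0.5 with
    # lr_decay_steps=5000 (hypothetical values) halves the lr every 5000
    # steps, while the default lr_decay=1.0 keeps it constant. With
    # plateau=True, ReduceLROnPlateau instead multiplies the lr by 0.75 once
    # the monitored dev accuracy stops improving for 5 evaluations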

    # load model and vocabularies if resuming
    if resume:
        if os.path.isfile(resume):
            print("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume)
            start_iter = checkpoint['iter_i']
            best_pos_acc = checkpoint['best_pos_acc']
            test_pos_acc = checkpoint['test_pos_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (iter {})".format(
                resume, checkpoint['iter_i']))
        else:
            print("=> no checkpoint found at '{}'".format(resume))

    print_parameters(model)

    # print some stuff just for fun
    logger.info("Most common words: %s" % _form.vocab.freqs.most_common(20))
    logger.info("Word vocab size: %s" % n_words)
    logger.info("Most common XPOS-tags: %s" % _pos.vocab.freqs.most_common())
    logger.info("POS vocab size: %s" % n_tags)
    # logger.info("Most common chars: %s" % _chars.nesting_field.vocab.freqs.most_common())
    logger.info("Chars vocab size: %s" % n_chars)

    print("First training example:")
    print_example(train_dataset[0])

    print("First dev example:")
    print_example(dev_dataset[0])

    print("First test example:")
    print_example(test_dataset[0])

    logger.info("Training starts..")
    upos_var, morph_var = None, None
    for iter_i in range(start_iter, n_iters + 1):

        # approximate number of iterations per epoch (token-based batching)
        epoch_done = train_dataset.n_tokens // batch_size

        # if not plateau and iter_i % epoch_done == 0:  # TODO: fix
        #   scheduler.step()
        scheduler.step()
        model.train()

        batch = next(iter(train_iter))
        form_var, lengths = batch.form

        pos_var, pos_lengths = batch.pos
        if upos_pred:
            upos_var, _ = batch.upos
        else:
            upos_var = None

        char_var, sentence_lengths, word_lengths = batch.chars
        lengths = lengths.view(-1).tolist()

        result = model(form_var=form_var,
                       char_var=char_var,
                       pos_var=pos_var,
                       lengths=lengths,
                       word_lengths=word_lengths,
                       pos_lengths=pos_lengths)

        if upos_pred:
            targets = dict(pos=batch.pos, upos=batch.upos)
        else:
            targets = dict(pos=batch.pos, upos=None)

        all_losses = model.get_loss(scores=result, targets=targets)

        loss = all_losses['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()
        optimizer.zero_grad()

        if iter_i % print_every == 0:

            # get scores for this batch
            upos_predictions = []
            if model.tagger == "linear" or model.tagger == "mlp":
                if model.upos_pred:
                    upos_predictions = result['output']['upos'].max(2)[1]
                    pos_predictions = result['output']['xpos'].max(2)[1]
                else:
                    pos_predictions = result['output']['xpos'].max(2)[1]
            else:
                pos_predictions = result['sequence']

            predictions = dict(pos=pos_predictions, upos=upos_predictions)
            if model.upos_pred:
                targets = dict(pos=batch.pos, upos=batch.upos)
            else:
                targets = dict(pos=batch.pos, upos=None)

            pos_acc, upos_acc = model.get_accuracy(predictions=predictions,
                                                   targets=targets)

            if not plateau:
                lr = scheduler.get_lr()[0]
            else:
                lr = [group['lr'] for group in optimizer.param_groups][0]

            fmt = "Iter %08d loss %8.4f pos-acc %5.2f upos-acc %5.2f lr %.5f"

            logger.info(fmt % (iter_i, loss, pos_acc, upos_acc, lr))

        if iter_i % eval_every == 0:

            # parse dev set and save to file for official evaluation
            dev_out_path = 'dev.iter%08d.conll' % iter_i
            dev_out_path = os.path.join(output_dir, dev_out_path)
            predict_and_save(dataset=dev_dataset,
                             model=model,
                             dataset_path=dev_path,
                             out_path=dev_out_path)

            _dev_pos_acc, _dev_upos_acc = get_pos_acc(dev_path, dev_out_path,
                                                      ud)

            logger.info("Evaluation dev Iter %08d "
                        "pos-acc %5.2f upos-acc %5.2f" %
                        (iter_i, _dev_pos_acc, _dev_upos_acc))

            # parse test set and save to file for official evaluation
            test_out_path = 'test.iter%08d.conll' % iter_i
            test_out_path = os.path.join(output_dir, test_out_path)
            predict_and_save(dataset=test_dataset,
                             model=model,
                             dataset_path=test_path,
                             out_path=test_out_path)
            _test_pos_acc, _test_upos_acc = get_pos_acc(
                test_path, test_out_path, ud)

            logger.info("Evaluation test Iter %08d "
                        "pos-acc %5.2f upos-acc %5.2f" %
                        (iter_i, _test_pos_acc, _test_upos_acc))

            if plateau:
                scheduler.step(_dev_pos_acc)

            if _dev_pos_acc > best_pos_acc:
                best_iter = iter_i
                best_pos_acc = _dev_pos_acc
                test_pos_acc = _test_pos_acc
                is_best = True
            else:
                is_best = False

            save_checkpoint(
                output_dir, {
                    'iter_i': iter_i,
                    'state_dict': model.state_dict(),
                    'best_iter': best_iter,
                    'test_pos_acc': test_pos_acc,
                    'optimizer': optimizer.state_dict(),
                }, False)

    logger.info("Done Training")
    logger.info(
        "Best model Iter %08d Dev POS-acc %12.4f Test POS-acc %12.4f " %
        (best_iter, best_pos_acc, test_pos_acc))
class FNN: 
    def __init__(self):
        ## Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    def set_data(self, features, targets, D, denom_sq):
        self.features_np = features
        self.targets_np = targets
        self.D_np = D
        self.inv_denom_sq = denom_sq**-1
    
    def train(self, config):
        ## Internal config
        self.config = {}
        self.config['num_epochs']       = 5000
        self.config['n_hidden']         = 2
        self.config['hidden_size']      = 40
        self.config['batch_size']       = 10
        self.config['lr']               = 1e-2
        self.config['regularization']   = 1e-10
        # Overwrite internal config values given in the external config
        if config:
            for key in config.keys():
                self.config[key] = config[key]
        
        # Assume we're using ray.tune at first
        self.tuning = True
        
        ## Model
        self.config['input_size'] = self.features_np['train'].shape[1]
        self.config['output_size'] = self.targets_np['train'].shape[1]
        self.model = Model(self.config).to(self.device)
        
        ## Data loaders
        self.batch_size = self.config['batch_size']
        self.train_loader = data_loader.create_loader(
            self.features_np['train'],
            self.targets_np['train'],
            self.batch_size,
            True)
        self.validate_loader  = data_loader.create_loader(
            self.features_np['validate'],
            self.targets_np['validate'],
            self.features_np['validate'].shape[0],  # use all validation samples
            False)                                  # don't shuffle
        
        ## Hyperparameters
        self.num_epochs = self.config['num_epochs']
        self.learning_rate = self.config['lr']
        
        ## Loss and optimizer
        self.criterion = self.eps_reg_sq
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, eps=1e-8, weight_decay=self.config['regularization'])
        lambdaLR = lambda epoch: 1 / (1 + 0.005*epoch)
        self.scheduler = LambdaLR(self.optimizer, lr_lambda=lambdaLR)
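        # inverse-time decay: the learning rate at epoch e is lr / (1 + 0.005*e),
        # so with the default lr = 1e-2 it reaches half its initial value at
        # epoch 200 and roughly 1/26 of it by epoch 5000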
        
        self.train_start()
    
    def train_start(self):
        ## Train
        early_stop = False
        self.D = torch.from_numpy(self.D_np).float().to(self.device)
        
        for epoch in range(self.num_epochs):
            for i, (features, targets) in enumerate(self.train_loader):
                self.model.train()
                self.optimizer.zero_grad()
                
                # Move tensors to the configured device
                features = features.to(self.device)
                targets  =  targets.to(self.device)
                
                # Forward pass
                outputs = self.model(features)
                loss = self.criterion(outputs, targets) ** 0.5
                if torch.isnan(loss):
                    print('Loss became NaN, stopping training early')
                    early_stop = True
                    break  # break out of the batch loop

                # Backward and optimize
                loss.backward()
                self.optimizer.step()
            
            if early_stop:
                break  # break out of the epoch loop
                
            self.scheduler.step()
                
            if epoch%10==0 or epoch==self.num_epochs-1:
                validate_loss  = self.get_loss(self.validate_loader)
                train_loss = self.get_loss(self.train_loader)
                print('eps_reg: Epoch [{}/{}], LR: {:.2e}, Train loss: {:.2e}, Validate loss: {:.2e}'
                    .format(epoch+1, self.num_epochs, self.scheduler.get_lr()[0], train_loss.item()**0.5, validate_loss.item()**0.5))
                
                if self.tuning:
                    try:
                        tune.track.log(mean_loss=validate_loss.item(),
                                       episodes_this_iter=10)
                    except Exception:
                        # not running under ray.tune
                        self.tuning = False
        return self

    def eps_reg_sq(self, outputs, targets):
        # D-weighted mean squared error: sum of (D * (targets - outputs))**2
        # over the batch, divided by denom_sq and the batch size
        return torch.sum((self.D*(targets - outputs)) ** 2) * self.inv_denom_sq / targets.shape[0]
        
    def get_loss(self, loader):
        with torch.no_grad():
            self.model.eval()
            loss = 0.0
            for features, targets in loader:
                features = features.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(features)
                loss += self.criterion(outputs, targets)
            return loss/len(loader)

    def evaluate(self, features):
        with torch.no_grad():
            self.model.eval()
            output = self.model(torch.tensor(features).float())
            u_rb = output.numpy()
            return u_rb
    
    def save(self, model_dir, component):
        try:
            path_config     = os.path.join(tune.track.trial_dir(), 'config')
            path_state_dict = os.path.join(tune.track.trial_dir(), 'state_dict')
        except Exception:
            # not tuning
            path_config     = os.path.join(model_dir, 'FNN', component,'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component,'state_dict')
        with open(path_config, 'wb+') as f:
            pickle.dump(self.config, f)
        
        torch.save(self.model.state_dict(), path_state_dict)
    
    def load(self, model_dir, component):
        '''
        Finds and loads the best model from the ray.tune analysis results.
        '''
        try:
            path_analysis = os.path.join(model_dir,'FNN',component)
            analysis = tune.Analysis(path_analysis)
            df_temp = analysis.dataframe()
            idx = df_temp['mean_loss'].idxmin()
            logdir = df_temp.loc[idx]['logdir']
            path_config     = os.path.join(logdir,'config')
            path_state_dict = os.path.join(logdir,'state_dict')
        except Exception:
            # no tuning records
            path_config     = os.path.join(model_dir, 'FNN', component,'config')
            path_state_dict = os.path.join(model_dir, 'FNN', component,'state_dict')
            
        
        with open(path_config, 'rb') as f:
            config = pickle.load(f)
            self.model = Model(config).to(self.device)
        
        state_dict = torch.load(path_state_dict,
                                map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
def train(train_data_path: str,
          val_data_paths: dict,
          use_cuda: bool,
          model_name: str,
          is_baseline: bool,
          resume_from_file=None):
    logger.info("Loading Training set...")
    logger.info(model_name)
    train_iter, train_input_vocab, train_target_vocab = dataloader(
        train_data_path, batch_size=cfg.TRAIN.BATCH_SIZE, use_cuda=use_cuda)
    val_iters = {}
    for split_name, path in val_data_paths.items():
        val_iters[split_name], _, _ = dataloader(
            path,
            batch_size=cfg.VAL_BATCH_SIZE,
            use_cuda=use_cuda,
            input_vocab=train_input_vocab,
            target_vocab=train_target_vocab)

    pad_idx, sos_idx, eos_idx = train_target_vocab.stoi['<pad>'], train_target_vocab.stoi['<sos>'], \
                                train_target_vocab.stoi['<eos>']

    train_input_vocab_size, train_target_vocab_size = len(
        train_input_vocab.itos), len(train_target_vocab.itos)
    '''
    Input (command) [0]: batch_size x max_cmd_len       [1]: batch_size x 0 (len for each cmd)
    Situation: batch_size x grid x grid x feat_size
    Target (action) [0]: batch_size x max_action_len    [1]: batch_size x 0 (len for each action sequence)

    max_cmd_len = 6, max_action_len = 16
    '''
    logger.info("Done Loading Training set.")

    # if generate_vocabularies:
    #     training_set.save_vocabularies(input_vocab_path, target_vocab_path)
    #     logger.info("Saved vocabularies to {} for input and {} for target.".format(input_vocab_path, target_vocab_path))

    logger.info("Loading Dev. set...")

    # val_input_vocab_size, val_target_vocab_size = train_input_vocab_size, train_target_vocab_size

    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.

    # val_set.shuffle_data()
    logger.info("Done Loading Dev. set.")

    model = GSCAN_model(pad_idx,
                        eos_idx,
                        train_input_vocab_size,
                        train_target_vocab_size,
                        is_baseline=is_baseline,
                        output_directory=os.path.join(os.getcwd(),
                                                      cfg.OUTPUT_DIRECTORY,
                                                      model_name))

    model = model.cuda() if use_cuda else model

    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=cfg.TRAIN.SOLVER.LR,
                                 betas=(cfg.TRAIN.SOLVER.ADAM_BETA1,
                                        cfg.TRAIN.SOLVER.ADAM_BETA2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: cfg.TRAIN.SOLVER.LR_DECAY**
                         (t / cfg.TRAIN.SOLVER.LR_DECAY_STEP))

    start_iteration = 1
    best_exact_match = 0

    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < cfg.TRAIN.MAX_EPOCH:  # "iteration" here actually means epoch

        # Shuffle the dataset and loop over it.
        # training_set.shuffle_data()
        num_batch = 0
        for x in train_iter:
            is_best = False
            model.train()
            target_scores, target_position_scores = model(
                x.input, x.situation, x.target)

            loss = model.get_loss(target_scores, x.target[0])

            target_loss = 0
            if cfg.AUXILIARY_TASK:
                target_loss = model.get_auxiliary_loss(target_position_scores,
                                                       x.target)
            loss += cfg.TRAIN.WEIGHT_TARGET_LOSS * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if num_batch % cfg.PRINT_EVERY == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, x.target[0])
                if cfg.AUXILIARY_TASK:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, x.target)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))

            num_batch += 1

        if training_iteration % cfg.EVALUATE_EVERY == 0:
            with torch.no_grad():
                model.eval()
                logger.info("Evaluating..")
                test_exact_match = 0
                test_accuracy = 0
                try:
                    for split_name, val_iter in val_iters.items():
                        accuracy, exact_match, target_accuracy = evaluate(
                            val_iter,
                            model=model,
                            max_decoding_steps=30,
                            pad_idx=pad_idx,
                            sos_idx=sos_idx,
                            eos_idx=eos_idx,
                            max_examples_to_evaluate=None)
                        if split_name == 'dev':
                            test_exact_match = exact_match
                            test_accuracy = accuracy

                        logger.info(" %s Accuracy: %5.2f Exact Match: %5.2f "
                                    " Target Accuracy: %5.2f " %
                                    (split_name, accuracy, exact_match,
                                     target_accuracy))
                except Exception:
                    logger.exception("Evaluation failed")

                if test_exact_match > best_exact_match:
                    is_best = True
                    best_accuracy = test_accuracy
                    best_exact_match = test_exact_match
                    model.update_state(accuracy=test_accuracy,
                                       exact_match=test_exact_match,
                                       is_best=is_best)
                file_name = model_name + "checkpoint.{}th.tar".format(
                    str(training_iteration))
                # file_name = os.path.join(os.getcwd(), cfg.OUTPUT_DIRECTORY, model_name, file_name)
                if is_best:
                    logger.info("saving best model...")
                    model.save_checkpoint(
                        file_name=file_name,
                        is_best=is_best,
                        optimizer_state_dict=optimizer.state_dict())

        if training_iteration % cfg.SAVE_EVERY == 0:
            logger.info("forcing to save model every several epochs...")
            file_name = model_name + " checkpoint_force.{}th.tar".format(
                str(training_iteration))
            # file_name = os.path.join(os.getcwd(), cfg.OUTPUT_DIRECTORY, model_name, file_name)
            model.save_checkpoint(file_name=file_name,
                                  is_best=False,
                                  optimizer_state_dict=optimizer.state_dict())

        training_iteration += 1  # note: "iteration" represents an epoch here
    logger.info("Finished training.")
def train(
        data_path: str,
        data_directory: str,
        generate_vocabularies: bool,
        input_vocab_path: str,
        target_vocab_path: str,
        embedding_dimension: int,
        num_encoder_layers: int,
        encoder_dropout_p: float,
        encoder_bidirectional: bool,
        training_batch_size: int,
        test_batch_size: int,
        max_decoding_steps: int,
        num_decoder_layers: int,
        decoder_dropout_p: float,
        cnn_kernel_size: int,
        cnn_dropout_p: float,
        cnn_hidden_num_channels: int,
        simple_situation_representation: bool,
        decoder_hidden_size: int,
        encoder_hidden_size: int,
        learning_rate: float,
        adam_beta_1: float,
        adam_beta_2: float,
        lr_decay: float,
        lr_decay_steps: int,
        resume_from_file: str,
        max_training_iterations: int,
        output_directory: str,
        print_every: int,
        evaluate_every: int,
        conditional_attention: bool,
        auxiliary_task: bool,
        weight_target_loss: float,
        attention_type: str,
        k: int,
        max_training_examples,
        max_testing_examples,
        # SeqGAN params begin
        pretrain_gen_path,
        pretrain_gen_epochs,
        pretrain_disc_path,
        pretrain_disc_epochs,
        rollout_trails,
        rollout_update_rate,
        disc_emb_dim,
        disc_hid_dim,
        load_tensors_from_path,
        # SeqGAN params end
        seed=42,
        **kwargs):
    device = torch.device("cpu")
    cfg = locals().copy()
    torch.manual_seed(seed)

    logger.info("Loading Training set...")

    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies,
        k=k)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation,
        load_tensors_from_path=load_tensors_from_path
    )  # set this to False if no pickle file available

    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    # logger.info("Loading Dev. set...")
    # test_set = GroundedScanDataset(data_path, data_directory, split="dev",
    #                                input_vocabulary_file=input_vocab_path,
    #                                target_vocabulary_file=target_vocab_path, generate_vocabulary=False, k=0)
    # test_set.read_dataset(max_examples=max_testing_examples,
    #                       simple_situation_representation=simple_situation_representation)
    #
    # # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    # test_set.shuffle_data()

    # logger.info("Done Loading Dev. set.")

    generator = Model(
        input_vocabulary_size=training_set.input_vocabulary_size,
        target_vocabulary_size=training_set.target_vocabulary_size,
        num_cnn_channels=training_set.image_channels,
        input_padding_idx=training_set.input_vocabulary.pad_idx,
        target_pad_idx=training_set.target_vocabulary.pad_idx,
        target_eos_idx=training_set.target_vocabulary.eos_idx,
        **cfg)
    total_vocabulary = set(
        list(training_set.input_vocabulary._word_to_idx.keys()) +
        list(training_set.target_vocabulary._word_to_idx.keys()))
    total_vocabulary_size = len(total_vocabulary)
    discriminator = Discriminator(embedding_dim=disc_emb_dim,
                                  hidden_dim=disc_hid_dim,
                                  vocab_size=total_vocabulary_size,
                                  max_seq_len=max_decoding_steps)

    generator = generator.cuda() if use_cuda else generator
    discriminator = discriminator.cuda() if use_cuda else discriminator
    rollout = Rollout(generator, rollout_update_rate)
    log_parameters(generator)
    trainable_parameters = [
        parameter for parameter in generator.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = generator.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = generator.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    if pretrain_gen_path is None:
        print('Pretraining generator with MLE...')
        pre_train_generator(training_set,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_gen_epochs,
                            name='pretrained_generator')
    else:
        print('Load pretrained generator weights')
        generator_weights = torch.load(pretrain_gen_path)
        generator.load_state_dict(generator_weights)

    if pretrain_disc_path is None:
        print('Pretraining Discriminator....')
        train_discriminator(training_set,
                            discriminator,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_disc_epochs,
                            name="pretrained_discriminator")
    else:
        print('Loading Discriminator....')
        discriminator_weights = torch.load(pretrain_disc_path)
        discriminator.load_state_dict(discriminator_weights)

    logger.info("Training starts..")
    training_iteration = start_iteration
    torch.autograd.set_detect_anomaly(True)
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()

        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions, target_positions) in \
                training_set.get_data_iterator(batch_size=training_batch_size):

            is_best = False
            generator.train()

            # Forward pass.
            samples = generator.sample(
                batch_size=training_batch_size,
                max_seq_len=max(target_lengths).astype(int),
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                sos_idx=training_set.input_vocabulary.sos_idx,
                eos_idx=training_set.input_vocabulary.eos_idx)

            rewards = rollout.get_reward(samples, rollout_trails, input_batch,
                                         input_lengths, situation_batch,
                                         target_batch,
                                         training_set.input_vocabulary.sos_idx,
                                         training_set.input_vocabulary.eos_idx,
                                         discriminator)

            assert samples.shape == rewards.shape

            # exponentiate and flatten the rewards
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            if use_cuda:
                rewards = rewards.cuda()

            # get generator scores for sequence
            target_scores = generator.get_normalized_logits(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                samples=samples,
                sample_lengths=target_lengths,
                sos_idx=training_set.input_vocabulary.sos_idx)

            del samples

            # calculate loss on the generated sequence given the rewards
            loss = generator.get_gan_loss(target_scores, target_batch, rewards)

            del rewards

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step(training_iteration)
            optimizer.zero_grad()
            generator.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                # accuracy, exact_match = generator.get_metrics(target_scores, target_batch)
                learning_rate = scheduler.get_lr()[0]
                logger.info("Iteration %08d, loss %8.4f, learning_rate %.5f," %
                            (training_iteration, loss, learning_rate))
                # logger.info("Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                #             % (training_iteration, loss, accuracy, exact_match, learning_rate))
            del target_scores, target_batch

            # # Evaluate on test set.
            # if training_iteration % evaluate_every == 0:
            #     with torch.no_grad():
            #         generator.eval()
            #         logger.info("Evaluating..")
            #         accuracy, exact_match, target_accuracy = evaluate(
            #             test_set.get_data_iterator(batch_size=1), model=generator,
            #             max_decoding_steps=max_decoding_steps, pad_idx=test_set.target_vocabulary.pad_idx,
            #             sos_idx=test_set.target_vocabulary.sos_idx,
            #             eos_idx=test_set.target_vocabulary.eos_idx,
            #             max_examples_to_evaluate=kwargs["max_testing_examples"])
            #         logger.info("  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
            #                     " Target Accuracy: %5.2f" % (accuracy, exact_match, target_accuracy))
            #         if exact_match > best_exact_match:
            #             is_best = True
            #             best_accuracy = accuracy
            #             best_exact_match = exact_match
            #             generator.update_state(accuracy=accuracy, exact_match=exact_match, is_best=is_best)
            #         file_name = "checkpoint.pth.tar".format(str(training_iteration))
            #         if is_best:
            #             generator.save_checkpoint(file_name=file_name, is_best=is_best,
            #                                       optimizer_state_dict=optimizer.state_dict())

            rollout.update_params()

            train_discriminator(training_set,
                                discriminator,
                                training_batch_size,
                                generator,
                                seed,
                                epochs=1,
                                name="training_discriminator")
            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
            del loss

        torch.save(
            generator.state_dict(),
            '{}/{}'.format(output_directory,
                           'gen_{}_{}.ckpt'.format(training_iteration, seed)))
        torch.save(
            discriminator.state_dict(),
            '{}/{}'.format(output_directory,
                           'dis_{}_{}.ckpt'.format(training_iteration, seed)))

    logger.info("Finished training.")
def train_and_evaluate(model,
                       data_loader,
                       train_data,
                       val_data,
                       test_data,
                       optimizer,
                       metrics,
                       params,
                       model_dir,
                       data_encoder,
                       label_encoder,
                       restore_file=None,
                       save_model=True,
                       eval=True):
    from src.ner.utils import SummaryWriter, Label, plot

    # plotting tools
    train_summary_writer = SummaryWriter([*metrics] + ['loss'], name='train')
    val_summary_writer = SummaryWriter([*metrics] + ['loss'], name='val')
    test_summary_writer = SummaryWriter([*metrics] + ['loss'], name='test')
    writers = [train_summary_writer, val_summary_writer, test_summary_writer]
    labeller = Label(anchor_metric='f1_score',
                     anchor_writer='val')
    plots_dir = os.path.join(model_dir, 'plots')
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    start_epoch = -1
    if restore_file is not None:
        logging.info("Restoring parameters from {}".format(restore_file))
        checkpoint = utils.load_checkpoint(restore_file, model, optimizer)
        start_epoch = checkpoint['epoch']

    # save a snapshot of the parameters for reproducibility
    utils.save_dict_to_json(params.dict, os.path.join(model_dir, 'train_snapshot.json'))

    # variable initialization
    best_val_score = 0.0
    patience = 0
    early_stopping_metric = 'f1_score'

    # set the Learning rate Scheduler
    lambda_lr = lambda epoch: 1 / (1 + (params.lr_decay_rate * epoch))
    lr_scheduler = LambdaLR(optimizer, lr_lambda=lambda_lr, last_epoch=start_epoch)

    # train over epochs
    for epoch in range(start_epoch + 1, params.num_epochs):
        lr_scheduler.step()
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
        logging.info("Learning Rate : {}".format(lr_scheduler.get_lr()))

        # compute number of batches in one epoch (one full pass over the training set)
        # num_steps = (params.train_size + 1) // params.batch_size
        num_steps = (train_data['size'] + 1) // params.batch_size
        train_data_iterator = data_loader.batch_iterator(train_data, batch_size=params.batch_size, shuffle=True)
        train_metrics = train(model,
                              optimizer,
                              train_data_iterator,
                              metrics,
                              params,
                              num_steps,
                              data_encoder,
                              label_encoder)
        val_score = train_metrics[early_stopping_metric]
        is_best = val_score >= best_val_score
        train_summary_writer.update(train_metrics)

        if eval:
            # Evaluate for one epoch on validation set
            # num_steps = (params.val_size + 1) // params.batch_size
            num_steps = (val_data['size'] + 1) // params.batch_size
            val_data_iterator = data_loader.batch_iterator(val_data, batch_size=params.batch_size, shuffle=False)
            val_metrics = evaluate(model,
                                   val_data_iterator,
                                   metrics,
                                   num_steps,
                                   label_encoder,
                                   mode='val')

            val_score = val_metrics[early_stopping_metric]
            is_best = val_score >= best_val_score
            val_summary_writer.update(val_metrics)

            ### TEST
            # num_steps = (params.test_size + 1) // params.batch_size
            num_steps = (test_data['size'] + 1) // params.batch_size
            test_data_iterator = data_loader.batch_iterator(test_data, batch_size=params.batch_size, shuffle=False)
            test_metrics = evaluate(model,
                                    test_data_iterator,
                                    metrics,
                                    num_steps,
                                    label_encoder,
                                    mode='test')
            test_summary_writer.update(test_metrics)

        labeller.update(writers=writers)

        plot(writers=writers,
             plot_dir=plots_dir,
             save=True)

        # Save weights
        if save_model:
            utils.save_checkpoint({'epoch': epoch,
                                   'state_dict': model.state_dict(),
                                   'optim_dict': optimizer.state_dict()},
                                  is_best=is_best,
                                  checkpoint=model_dir)

            # save encoders only if they do not exist yet
            if not os.path.exists(os.path.join(model_dir, 'data_encoder.pkl')):
                utils.save_obj(data_encoder, os.path.join(model_dir, 'data_encoder.pkl'))
            if not os.path.exists(os.path.join(model_dir, 'label_encoder.pkl')):
                utils.save_obj(label_encoder, os.path.join(model_dir, 'label_encoder.pkl'))

        # If this is the best model so far, reset patience and save its metrics
        if is_best:
            patience = 0
            logging.info("- Found new best F1 score")
            best_val_score = val_score
            # Save best metrics in a json file in the model directory
            if eval:
                utils.save_dict_to_json(val_metrics, os.path.join(model_dir, 'plots', "metrics_val_best_weights.json"))
                utils.save_dict_to_json(test_metrics, os.path.join(model_dir, 'plots', "metrics_test_best_weights.json"))
            utils.save_dict_to_json(train_metrics, os.path.join(model_dir, 'plots', "metrics_train_best_weights.json"))
        else:
            if eval:
                patience += 1
                logging.info('current patience: {} ; max patience: {}'.format(patience, params.patience))
            if patience == params.patience:
                logging.info('patience reached. Exiting at epoch: {}'.format(epoch + 1))
                # Save latest metrics in a json file in the model directory before exiting
                if eval:
                    utils.save_dict_to_json(val_metrics, os.path.join(model_dir, 'plots', "metrics_val_last_weights.json"))
                    utils.save_dict_to_json(test_metrics,
                                            os.path.join(model_dir, 'plots', "metrics_test_last_weights.json"))
                utils.save_dict_to_json(train_metrics, os.path.join(model_dir, 'plots', "metrics_train_last_weights.json"))
                epoch = epoch - patience
                break

        # Save latest metrics in a json file in the model directory at end of epoch
        if eval:
            utils.save_dict_to_json(val_metrics, os.path.join(model_dir, 'plots', "metrics_val_last_weights.json"))
            utils.save_dict_to_json(test_metrics, os.path.join(model_dir, 'plots', "metrics_test_last_weights.json"))
        utils.save_dict_to_json(train_metrics, os.path.join(model_dir, 'plots', "metrics_train_last_weights.json"))
    return epoch
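# A minimal sketch of the inverse-time schedule used above
# (multiplier 1 / (1 + lr_decay_rate * epoch)), isolated from the NER loop.
# lr_decay_rate below is an illustrative value, not taken from params.
# Note: recent PyTorch expects optimizer.step() before scheduler.step(); the epoch
# loop above calls lr_scheduler.step() at the top of the epoch, the older idiom.
import torch
from torch.optim.lr_scheduler import LambdaLR

demo_model = torch.nn.Linear(4, 2)
demo_optimizer = torch.optim.SGD(demo_model.parameters(), lr=0.1)
lr_decay_rate = 0.05  # hypothetical
demo_scheduler = LambdaLR(
    demo_optimizer, lr_lambda=lambda epoch: 1.0 / (1.0 + lr_decay_rate * epoch))
for demo_epoch in range(3):
    # ... run one training epoch here ...
    demo_optimizer.step()
    demo_scheduler.step()
    print(demo_epoch, demo_optimizer.param_groups[0]['lr'])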
def train(model, tokenizer, train_data, valid_data, args, eos=False):
    model.train()

    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size, num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length, eos=eos, tokenizer_type=args.tokenizer))

    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(valid_dataset, sampler=SequentialSampler(valid_dataset),
                                  batch_size=args.eval_batch_size, num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length, eos=eos, tokenizer_type=args.tokenizer))

    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]

    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=eval(args.adam_betas), eps=args.eps,
                                 weight_decay=args.weight_decay)
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (x / args.num_warmup_steps) ** -0.5
    scheduler = LambdaLR(optimizer, lr_lambda)

    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        print("===EPOCH: ", epoch)
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            loss, items = calc_loss(model, batch)
            meter.add(*items)

            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()

            if step % args.log_interval == 0:
                lr = scheduler.get_lr()[0]
                loss_sent, loss_token = meter.average()

                logger.info(f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step, scope=locals(), summary=True,
                            train__lr=lr, train__loss_sent=loss_sent, train__token_ppl=math.exp(loss_token))
                meter.init()

            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate(model, valid_dataloader, args)
                prediction = correct(model, tokenizer, valid_noisy, args, eos=eos, length_limit=0.1)
                val_em = em(prediction, valid_clean)
                cnt = 0
                for noisy, pred, clean in zip(valid_noisy, prediction, valid_clean):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                    # print only the first 20 examples
                    cnt += 1
                    if cnt == 20:
                        break
                val_gleu = gleu(prediction, valid_clean)

                logger.info('-' * 89)
                logger.info(f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}')
                logger.info('-' * 89)
                nsml.report(step=step, scope=locals(), summary=True,
                            valid__loss_sent=val_loss, valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em, valid__gleu=val_gleu)

                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                meter.start += time.time() - start_eval

            if step >= args.max_steps:
                break
        #nsml.save(epoch)
        if step >= args.max_steps:
            break
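# A standalone sketch of the warmup-then-inverse-square-root schedule defined above:
# the multiplier rises linearly to 1 over num_warmup_steps, then decays as
# (step / num_warmup_steps) ** -0.5. The value below is illustrative, not from args.
import torch
from torch.optim.lr_scheduler import LambdaLR

demo_model = torch.nn.Linear(8, 8)
demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=1e-4)
num_warmup_steps = 1000  # hypothetical
demo_scheduler = LambdaLR(
    demo_optimizer,
    lr_lambda=lambda x: x / num_warmup_steps
    if x <= num_warmup_steps else (x / num_warmup_steps) ** -0.5)
for demo_step in range(5):
    demo_optimizer.step()
    demo_scheduler.step()
# At step 0 the multiplier is 0, so the very first update above runs with lr == 0.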
def train_and_valid_(net,
                     criterion,
                     optimizer,
                     train_loader,
                     valid_loader,
                     cfg,
                     is_lr_adjust=True,
                     is_lr_warmup=False):

    # ------------------ Configuration ------------------------------
    # If a checkpoint exists and its use is allowed, load its weights before training
    if os.path.exists(cfg.checkpoints) and cfg.use_checkpoints:
        # Load the saved weights
        net.load_state_dict(torch.load(cfg.checkpoints))
        print('Loading checkpoint weights...')

    # Configure the learning rate scheduler (decays per epoch by default); two kinds of decay
    if is_lr_adjust:
        # Decay after a fixed number of epochs <StepLR>
        lr_shcleduler_step = StepLR(optimizer=optimizer,
                                    step_size=cfg.lr_decay_step)
    elif is_lr_warmup:  # if True, enable learning rate warmup
        # Define the lambda expression <LambdaLR>
        lr_lambda = lambda epoch: epoch / cfg.lr_warmup_step
        lr_shcleduler_warmup = LambdaLR(optimizer=optimizer,
                                        lr_lambda=lr_lambda)
        lr_shcleduler_warmup.step()

    # Create the writer used for logging
    writer = SummaryWriter(cfg.log_dir)

    # ------------------ Define the training and validation sub-functions --------------------
    # Training sub-function
    def _train(train_loader, num_step):
        print('  training stage....')
        # Switch the network to training mode; reset the gradient tensors
        net.train()
        optimizer.zero_grad()
        # Accuracy, loss, number of batches, and total number of samples
        train_acc = 0.0
        train_loss = 0.0
        num_batch = 0
        num_samples = 0

        # Train the network
        for index, data in enumerate(train_loader, start=0):
            # Get the batch of training data and move it to the GPU
            images, labels = data
            # print(images.size(), labels)
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)

            # Forward pass; apply softmax so predictions lie in [0, 1]; compute the loss
            outputs = net(images)
            outputs = F.softmax(outputs, dim=1)
            loss = criterion(outputs, labels)

            # Index of the highest-probability class for each prediction
            preds = torch.argmax(outputs, dim=1)

            # Batch accuracy: fraction of correctly predicted samples in the batch
            # Accumulate the accuracy, loss, and batch count
            acc = torch.sum(preds == labels).item()
            train_acc += acc
            train_loss += loss
            num_batch += 1
            num_samples += images.size(0)

            # Use gradient accumulation when GPU memory is limited; otherwise do a normal
            # backward pass (compute gradients) and optimizer step
            if cfg.grad_accuml is True and cfg.batch_size < 128:
                # Scale the loss so the accumulated gradients are averaged
                loss = loss / cfg.batch_accumulate_size
                loss.backward()
                # Update parameters and reset gradients once enough batches have accumulated
                if (index + 1) % cfg.batch_accumulate_size == 0:
                    optimizer.step()
                    optimizer.zero_grad()
            else:
                # Compute gradients, update the parameters, reset the gradient tensors
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # Periodically report the loss and accuracy
            if (index + 1) % cfg.print_rate == 0:
                # Print the batch loss and accuracy
                print('   batch:{}, batch_loss:{:.4f}, batch_acc:{:.4f}\n'.
                      format(index, loss, acc / images.size(0)))

            # Log the loss and accuracy of this training batch
            # writer.add_scalar('Train/Loss', scalar_value=loss, global_step=index)  # single tag
            writer.add_scalars(main_tag='Train(batch)',
                               tag_scalar_dict={
                                   'batch_loss': loss,
                                   'batch_accuracy': acc / images.size(0)
                               },
                               global_step=num_step)
            # Update the global step
            num_step += 1

        # Compute the training accuracy and loss
        train_acc = train_acc / num_samples
        train_loss = train_loss / num_batch
        return train_acc, train_loss, num_step

    # Validation sub-function
    def _valid(valid_loader):
        print('  valid stage...')
        # Switch the network to eval mode; track accuracy, loss, and the number of batches
        net.eval()
        valid_acc = 0.0
        valid_loss = 0.0
        num_batch = 0
        num_samples = 0

        # Evaluate on the validation set
        with torch.no_grad():  # no gradients, to reduce memory usage
            for index, data in enumerate(valid_loader, start=0):
                images, labels = data
                # Move the validation data to the GPU
                images, labels = images.to(cfg.device), labels.to(cfg.device)
                # Forward pass; apply softmax so predictions lie in [0, 1]
                outputs = net(images)
                outputs = F.softmax(outputs, dim=1)
                # Index of the highest-probability class; compute the loss
                pred = torch.argmax(outputs, dim=1)
                loss = criterion(outputs, labels)

                # Count correct predictions; accumulate the loss
                valid_acc += torch.sum((pred == labels)).item()
                valid_loss += loss
                num_batch += 1
                num_samples += images.size(0)

        # Compute the validation accuracy and loss
        valid_acc = valid_acc / num_samples
        valid_loss = valid_loss / num_batch

        return valid_acc, valid_loss

    # ---------------------------- Start epoch-based training --------------------------------
    # Training start time, best validation accuracy (for saving the best model), and total step count
    start_time = time.time()
    best_acc = 0.0
    num_step = 0

    # Train over epochs
    for epoch in range(cfg.epochs):
        # Epoch start time and epoch info
        epoch_start_time = time.time()
        print('Epoch {}/{}'.format(epoch, cfg.epochs - 1))
        print('-' * 20)

        # Train
        train_acc, train_loss, num_step = _train(train_loader, num_step)
        # Validate
        valid_acc, valid_loss = _valid(valid_loader)

        # Adjust the learning rate
        # Warm up the learning rate during the first few epochs
        if is_lr_warmup is True and epoch < cfg.lr_warmup_step:
            lr_shcleduler_warmup.step()
            print('  epoch:{}/{}, learning rate warmup...{}'.format(
                epoch, cfg.lr_warmup_step - 1, lr_shcleduler_warmup.get_lr()))
        elif is_lr_adjust:  # after warmup the learning rate returns to its base value, or simply decays per epoch
            lr_shcleduler_step.step()

        # Report the epoch's average training/validation loss and accuracy
        epoch_time = time.time() - epoch_start_time
        print('   epoch:{}/{}, time:{:.0f}m {:.0f}s'.format(
            epoch, cfg.epochs, epoch_time // 60, epoch_time % 60))
        print(
            '   train_loss:{:.4f}, train_acc:{:.4f}\n   valid_loss:{:.4f}, valid_acc:{:.4f}'
            .format(train_loss, train_acc, valid_loss, valid_acc))

        # Log the epoch results
        writer.add_scalars(main_tag='Train(epoch)',
                           tag_scalar_dict={
                               'train_loss': train_loss,
                               'train_acc': train_acc,
                               'valid_loss': valid_loss,
                               'valid_acc': valid_acc
                           },
                           global_step=epoch)

        # Keep the best model parameters
        if valid_acc > best_acc:
            # Update the best accuracy and save the best model parameters
            best_acc = valid_acc
            torch.save(net.state_dict(), cfg.checkpoints)
            print('  epoch:{}, update model...'.format(epoch))
        print()

    # Total training time; report the best accuracy
    end_time = time.time() - start_time
    print('Training complete in {:.0f}m {:.0f}s'.format(
        end_time // 60, end_time % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Close the writer
    writer.close()
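# A compact sketch of the two scheduler options above: a StepLR that decays every
# lr_decay_step epochs, or a LambdaLR warmup over the first lr_warmup_step epochs.
# The warmup lambda is capped at 1.0 here so this standalone loop stays bounded;
# the function above instead stops stepping the warmup scheduler after
# cfg.lr_warmup_step epochs. All numbers are illustrative, not taken from cfg.
import torch
from torch.optim.lr_scheduler import LambdaLR, StepLR

demo_model = torch.nn.Linear(4, 4)
demo_optimizer = torch.optim.SGD(demo_model.parameters(), lr=0.1)
use_warmup = True
if use_warmup:
    lr_warmup_step = 5  # hypothetical
    demo_scheduler = LambdaLR(
        demo_optimizer, lr_lambda=lambda epoch: min(1.0, epoch / lr_warmup_step))
else:
    demo_scheduler = StepLR(demo_optimizer, step_size=30)  # hypothetical step_size

for demo_epoch in range(10):
    # ... train and validate one epoch here ...
    demo_scheduler.step()
    print(demo_epoch, demo_optimizer.param_groups[0]['lr'])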
    def train(self) -> None:
        r"""Main method for training PPO.

        Returns:
            None
        """
        global lr_lambda
        logger.info(f"config: {self.config}")
        random.seed(self.config.SEED)
        np.random.seed(self.config.SEED)
        torch.manual_seed(self.config.SEED)

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME),
                                   auto_reset_done=False)

        ppo_cfg = self.config.RL.PPO
        self.device = (torch.device("cuda", self.config.TORCH_GPU_ID)
                       if torch.cuda.is_available() else torch.device("cpu"))
        if not os.path.isdir(self.config.CHECKPOINT_FOLDER):
            os.makedirs(self.config.CHECKPOINT_FOLDER)
        self._setup_actor_critic_agent(ppo_cfg)
        logger.info("agent number of parameters: {}".format(
            sum(param.numel() for param in self.agent.parameters())))

        rollouts = RolloutStorage(ppo_cfg.num_steps, self.envs.num_envs,
                                  self.envs.observation_spaces[0],
                                  self.envs.action_spaces[0],
                                  ppo_cfg.hidden_size)
        rollouts.to(self.device)

        observations = self.envs.reset()
        batch = batch_obs(observations)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        # episode_rewards and episode_counts accumulates over the entire training course
        episode_rewards = torch.zeros(self.envs.num_envs, 1)
        episode_spls = torch.zeros(self.envs.num_envs, 1)
        episode_steps = torch.zeros(self.envs.num_envs, 1)
        episode_counts = torch.zeros(self.envs.num_envs, 1)
        episode_distances = torch.zeros(self.envs.num_envs, 1)
        current_episode_reward = torch.zeros(self.envs.num_envs, 1)
        current_episode_step = torch.zeros(self.envs.num_envs, 1)
        window_episode_reward = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_spl = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_step = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_counts = deque(maxlen=ppo_cfg.reward_window_size)
        window_episode_distances = deque(maxlen=ppo_cfg.reward_window_size)

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

        if ppo_cfg.use_linear_lr_decay:

            def lr_lambda(x):
                return linear_decay(x, self.config.NUM_UPDATES)
        elif ppo_cfg.use_exponential_lr_decay:

            def lr_lambda(x):
                return exponential_decay(x, self.config.NUM_UPDATES,
                                         ppo_cfg.exp_decay_lambda)
        else:

            def lr_lambda(x):
                return 1

        lr_scheduler = LambdaLR(optimizer=self.agent.optimizer,
                                lr_lambda=lr_lambda)

        with TensorboardWriter(self.config.TENSORBOARD_DIR,
                               flush_secs=self.flush_secs) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay or ppo_cfg.use_exponential_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                for step in range(ppo_cfg.num_steps):
                    delta_pth_time, delta_env_time, delta_steps = self._collect_rollout_step(
                        rollouts, current_episode_reward, current_episode_step,
                        episode_rewards, episode_spls, episode_counts,
                        episode_steps, episode_distances)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps += delta_steps

                delta_pth_time, value_loss, action_loss, dist_entropy = self._update_agent(
                    ppo_cfg, rollouts)
                pth_time += delta_pth_time

                window_episode_reward.append(episode_rewards.clone())
                window_episode_spl.append(episode_spls.clone())
                window_episode_step.append(episode_steps.clone())
                window_episode_counts.append(episode_counts.clone())
                window_episode_distances.append(episode_distances.clone())

                losses = [value_loss, action_loss, dist_entropy]
                stats = zip(
                    ["count", "reward", "step", 'spl', 'distance'],
                    [
                        window_episode_counts, window_episode_reward,
                        window_episode_step, window_episode_spl,
                        window_episode_distances
                    ],
                )
                deltas = {
                    k:
                    ((v[-1] -
                      v[0]).sum().item() if len(v) > 1 else v[0].sum().item())
                    for k, v in stats
                }
                deltas["count"] = max(deltas["count"], 1.0)

                # this reward is averaged over all the episodes happened during window_size updates
                # approximately number of steps is window_size * num_steps
                writer.add_scalar("Environment/Reward",
                                  deltas["reward"] / deltas["count"],
                                  count_steps)

                writer.add_scalar("Environment/SPL",
                                  deltas["spl"] / deltas["count"], count_steps)

                logging.debug('Number of steps: {}'.format(deltas["step"] /
                                                           deltas["count"]))
                writer.add_scalar("Environment/Episode_length",
                                  deltas["step"] / deltas["count"],
                                  count_steps)

                writer.add_scalar("Environment/Distance_to_goal",
                                  deltas["distance"] / deltas["count"],
                                  count_steps)

                # writer.add_scalars(
                #     "losses",
                #     {k: l for l, k in zip(losses, ["value", "policy"])},
                #     count_steps,
                # )

                writer.add_scalar('Policy/Value_Loss', value_loss, count_steps)
                writer.add_scalar('Policy/Action_Loss', action_loss,
                                  count_steps)
                writer.add_scalar('Policy/Entropy', dist_entropy, count_steps)
                writer.add_scalar('Policy/Learning_Rate',
                                  lr_scheduler.get_lr()[0], count_steps)

                # log stats
                if update > 0 and update % self.config.LOG_INTERVAL == 0:
                    logger.info("update: {}\tfps: {:.3f}\t".format(
                        update,
                        count_steps / ((time.time() - t_start) + prev_time)))

                    logger.info(
                        "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                        "frames: {}".format(update, env_time, pth_time,
                                            count_steps))

                    window_rewards = (window_episode_reward[-1] -
                                      window_episode_reward[0]).sum()
                    window_counts = (window_episode_counts[-1] -
                                     window_episode_counts[0]).sum()

                    if window_counts > 0:
                        logger.info(
                            "Average window size {} reward: {:3f}".format(
                                len(window_episode_reward),
                                (window_rewards / window_counts).item(),
                            ))
                    else:
                        logger.info("No episodes finish in current window")

                # checkpoint model
                if update % self.config.CHECKPOINT_INTERVAL == 0:
                    self.save_checkpoint(f"ckpt.{count_checkpoints}.pth")
                    count_checkpoints += 1

            self.envs.close()
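# A minimal sketch of the three lr_lambda choices above: linear decay, exponential
# decay, or a constant multiplier. demo_linear_decay / demo_exponential_decay below
# are simple stand-ins written for this sketch; the helper functions used above may
# differ in detail. num_updates is an illustrative value.
import torch
from torch.optim.lr_scheduler import LambdaLR

def demo_linear_decay(update, num_updates):
    return 1.0 - update / float(num_updates)

def demo_exponential_decay(update, num_updates, decay_lambda):
    return decay_lambda ** (update / float(num_updates))

demo_model = torch.nn.Linear(2, 2)
demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=2.5e-4)
num_updates = 10000  # hypothetical
demo_scheduler = LambdaLR(demo_optimizer,
                          lr_lambda=lambda x: demo_linear_decay(x, num_updates))
for demo_update in range(3):
    demo_optimizer.step()
    demo_scheduler.step()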
def train():
    train_data = ACNet_data.SUNRGBD(transform=transforms.Compose([ACNet_data.scaleNorm(),
                                                                   ACNet_data.RandomScale((1.0, 1.4)),
                                                                   ACNet_data.RandomHSV((0.9, 1.1),
                                                                                         (0.9, 1.1),
                                                                                         (25, 25)),
                                                                   ACNet_data.RandomCrop(image_h, image_w),
                                                                   ACNet_data.RandomFlip(),
                                                                   ACNet_data.ToTensor(),
                                                                   ACNet_data.Normalize()]),
                                     phase_train=True,
                                     data_dir=args.data_dir)
    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, pin_memory=False)

    num_train = len(train_data)

    if args.last_ckpt:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=False)
    else:
        model = ACNet_models_V1.ACNet(num_class=40, pretrained=True)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
    CEL_weighted = utils.CrossEntropyLoss2d(weight=nyuv2_frq)
    model.train()
    model.to(device)
    CEL_weighted.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay)

    global_step = 0

    if args.last_ckpt:
        global_step, args.start_epoch = load_ckpt(model, optimizer, args.last_ckpt, device)

    lr_decay_lambda = lambda epoch: args.lr_decay_rate ** (epoch // args.lr_epoch_per_decay)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_decay_lambda)

    writer = SummaryWriter(args.summary_dir)

    for epoch in range(int(args.start_epoch), args.epochs):

        scheduler.step(epoch)
        local_count = 0
        last_count = 0
        end_time = time.time()
        if epoch % args.save_epoch_freq == 0 and epoch != args.start_epoch:
            save_ckpt(args.ckpt_dir, model, optimizer, global_step, epoch,
                      local_count, num_train)

        for batch_idx, sample in enumerate(train_loader):

            image = sample['image'].to(device)
            depth = sample['depth'].to(device)
            target_scales = [sample[s].to(device) for s in ['label', 'label2', 'label3', 'label4', 'label5']]
            optimizer.zero_grad()
            pred_scales = model(image, depth, args.checkpoint)
            loss = CEL_weighted(pred_scales, target_scales)
            loss.backward()
            optimizer.step()
            local_count += image.data.shape[0]
            global_step += 1
            if global_step % args.print_freq == 0 or global_step == 1:

                time_inter = time.time() - end_time
                count_inter = local_count - last_count
                print_log(global_step, epoch, local_count, count_inter,
                          num_train, loss, time_inter)
                end_time = time.time()

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param.clone().cpu().data.numpy(), global_step, bins='doane')
                grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
                writer.add_image('image', grid_image, global_step)
                grid_image = make_grid(depth[:3].clone().cpu().data, 3, normalize=True)
                writer.add_image('depth', grid_image, global_step)
                grid_image = make_grid(utils.color_label(torch.max(pred_scales[0][:3], 1)[1] + 1), 3, normalize=False,
                                       range=(0, 255))
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(utils.color_label(target_scales[0][:3]), 3, normalize=False, range=(0, 255))
                writer.add_image('Groundtruth label', grid_image, global_step)
                writer.add_scalar('CrossEntropyLoss', loss.data, global_step=global_step)
                writer.add_scalar('Learning rate', scheduler.get_lr()[0], global_step=global_step)
                last_count = local_count

    save_ckpt(args.ckpt_dir, model, optimizer, global_step, args.epochs,
              0, num_train)

    print("Training completed ")
def main(args):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size,
                            ratio=args.resize_ratio,
                            scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(),
                                       netG_T2S.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(),
                                       netD_T.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs
                                                ) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'test':
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose(
        [Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
         ToPILImage()])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(
            logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
              netD_T, criterion_gan, criterion_cycle, criterion_identity,
              optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch,
              visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))

    if args.translated_root is not None:
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()
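# A standalone sketch of the CycleGAN-style schedule above: the multiplier stays at
# 1.0 for `epochs` epochs and then decays linearly to zero over `epochs_decay`
# further epochs. The numbers are illustrative, not taken from args.
import torch
from torch.optim.lr_scheduler import LambdaLR

demo_model = torch.nn.Linear(3, 3)
demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=2e-4, betas=(0.5, 0.999))
epochs, epochs_decay = 100, 100  # hypothetical
demo_scheduler = LambdaLR(
    demo_optimizer,
    lr_lambda=lambda epoch: 1.0 - max(0, epoch - epochs) / float(epochs_decay))
for demo_epoch in range(epochs + epochs_decay):
    # ... train one epoch here ...
    demo_optimizer.step()
    demo_scheduler.step()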
criterion = CrossEntropyLoss2d()

metrics = Metrics()
if store.metrics:
    metrics.load_state_dict(store.metrics)

if FAKE:
    print('STOP TRAINING')
    exit(0)

# LOOP
print(f'Starting ({now_str()})')
iter_count = len(data_set) // BATCH_SIZE
while epoch < first_epoch + EPOCH_COUNT:
    iter_metrics = Metrics()
    lr = scheduler.get_lr()[0]
    for i, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs).to(device)
        loss = criterion(outputs, labels)
        coef = Coef.calc(outputs, labels)
        iter_metrics.append_loss(loss.item())
        iter_metrics.append_coef(coef)
        pp('epoch[{ep}]:{i}/{I} iou:{c.pjac:.4f} acc:{c.pdice:.4f} loss:{loss:.4f} lr:{lr:.4f} ({t})'.format(
            ep=epoch, i=i+1, I=iter_count, lr=lr, t=now_str(), loss=loss.item(), c=coef))
        loss.backward()
        optimizer.step()
    pp('epoch[{ep}]:Done. iou:{c.pjac:.4f} acc:{c.pdice:.4f} gsi:{c.gsensi:.4f} gsp:{c.gspec:.4f} tsi:{c.tsensi:.4f} tsp:{c.tspec:.4f} loss:{loss:.4f} lr:{lr:.4f} ({t})'.format(
        ep=epoch, t=now_str(), lr=lr, loss=iter_metrics.avg('losses'), c=iter_metrics.avg_coef()))
    def train(self,
              model,
              train_loader,
              val_loader=None,
              num_epochs=10,
              log_nth=0,
              model_args={}):
        """
		Train a given model with the provided data.

		Inputs:
		- model: model object initialized from a torch.nn.Module
		- train_loader: train data in torch.utils.data.DataLoader
		- val_loader: val data in torch.utils.data.DataLoader
		- num_epochs: total number of training epochs
		- log_nth: log training accuracy and loss every nth iteration
		"""

        self.writer = tb.SummaryWriter(self.tb_dir)
        self.val_writer = tb.SummaryWriter(self.tb_val_dir)

        # filter out frcnn if this is added to the module
        parameters = [
            param for name, param in model.named_parameters()
            if 'frcnn' not in name
        ]
        optim = self.optim(parameters, **self.optim_args)

        if self.lr_scheduler_lambda:
            scheduler = LambdaLR(optim, lr_lambda=self.lr_scheduler_lambda)
        else:
            scheduler = None

        self._reset_histories()
        iter_per_epoch = len(train_loader)

        print('START TRAIN.')
        ############################################################################
        # TODO:                                                                    #
        # Write your own personal training method for our solver. In Each epoch    #
        # iter_per_epoch shuffled training batches are processed. The loss for     #
        # each batch is stored in self.train_loss_history. Every log_nth iteration #
        # the loss is logged. After one epoch the training accuracy of the last    #
        # mini batch is logged and stored in self.train_acc_history.               #
        # We validate at the end of each epoch, log the result and store the       #
        # accuracy of the entire validation set in self.val_acc_history.           #
        #
        # Your logging should look something like:                                 #
        #   ...                                                                    #
        #   [Iteration 700/4800] TRAIN loss: 1.452                                 #
        #   [Iteration 800/4800] TRAIN loss: 1.409                                 #
        #   [Iteration 900/4800] TRAIN loss: 1.374                                 #
        #   [Epoch 1/5] TRAIN acc/loss: 0.560/1.374                                #
        #   [Epoch 1/5] VAL   acc/loss: 0.539/1.310                                #
        #   ...                                                                    #
        ############################################################################

        for epoch in range(num_epochs):
            # TRAINING
            if scheduler:
                scheduler.step()
                print("[*] New learning rate(s): {}".format(
                    scheduler.get_lr()))

            now = time.time()

            for i, batch in enumerate(train_loader, 1):
                #inputs, labels = Variable(batch[0]), Variable(batch[1])

                optim.zero_grad()
                losses = model.sum_losses(batch, **model_args)
                losses['total_loss'].backward()
                optim.step()

                for k, v in losses.items():
                    if k not in self._losses.keys():
                        self._losses[k] = []
                    self._losses[k].append(v.data.cpu().numpy())

                if log_nth and i % log_nth == 0:
                    next_now = time.time()
                    print('[Iteration %d/%d] %.3f s/it' %
                          (i + epoch * iter_per_epoch,
                           iter_per_epoch * num_epochs,
                           (next_now - now) / log_nth))
                    now = next_now

                    for k, v in self._losses.items():
                        last_log_nth_losses = self._losses[k][-log_nth:]
                        train_loss = np.mean(last_log_nth_losses)
                        print('%s: %.3f' % (k, train_loss))
                        self.writer.add_scalar(k, train_loss,
                                               i + epoch * iter_per_epoch)

            # VALIDATION
            if val_loader and log_nth:
                model.eval()
                for i, batch in enumerate(val_loader):

                    losses = model.sum_losses(batch, **model_args)

                    for k, v in losses.items():
                        if k not in self._val_losses.keys():
                            self._val_losses[k] = []
                        self._val_losses[k].append(v.data.cpu().numpy())

                    if i >= log_nth:
                        break

                model.train()
                for k, v in self._losses.items():
                    last_log_nth_losses = self._val_losses[k][-log_nth:]
                    val_loss = np.mean(last_log_nth_losses)
                    self.val_writer.add_scalar(k, val_loss,
                                               (epoch + 1) * iter_per_epoch)

                #blobs_val = data_layer_val.forward()
                #tracks_val = model.val_predict(blobs_val)
                #im = plot_tracks(blobs_val, tracks_val)
                #self.val_writer.add_image('val_tracks', im, (epoch+1) * iter_per_epoch)

            self.snapshot(model, (epoch + 1) * iter_per_epoch)

            self._reset_histories()

        self.writer.close()
        self.val_writer.close()

        ############################################################################
        #                             END OF YOUR CODE                             #
        ############################################################################
        print('FINISH.')
    def train(self) -> None:
        r"""Main method for DD-PPO.

        Returns:
            None
        """
        self.local_rank, tcp_store = init_distrib_slurm(
            self.config.RL.DDPPO.distrib_backend)
        add_signal_handlers()

        # Stores the number of workers that have finished their rollout
        num_rollouts_done_store = distrib.PrefixStore("rollout_tracker",
                                                      tcp_store)
        num_rollouts_done_store.set("num_done", "0")

        self.world_rank = distrib.get_rank()
        self.world_size = distrib.get_world_size()

        self.config.defrost()
        self.config.TORCH_GPU_ID = self.local_rank
        self.config.SIMULATOR_GPU_ID = self.local_rank
        # Multiply by the number of simulators to make sure they also get unique seeds
        self.config.TASK_CONFIG.SEED += (self.world_rank *
                                         self.config.NUM_PROCESSES)
        self.config.freeze()

        random.seed(self.config.TASK_CONFIG.SEED)
        np.random.seed(self.config.TASK_CONFIG.SEED)
        torch.manual_seed(self.config.TASK_CONFIG.SEED)

        if torch.cuda.is_available():
            self.device = torch.device("cuda", self.local_rank)
            torch.cuda.set_device(self.device)
        else:
            self.device = torch.device("cpu")

        self.envs = construct_envs(self.config,
                                   get_env_class(self.config.ENV_NAME))

        ppo_cfg = self.config.RL.PPO
        if (not os.path.isdir(self.config.CHECKPOINT_FOLDER)
                and self.world_rank == 0):
            os.makedirs(self.config.CHECKPOINT_FOLDER)

        self._setup_actor_critic_agent(ppo_cfg)
        self.agent.init_distributed(find_unused_params=True)
        if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
            self.belief_predictor.init_distributed(find_unused_params=True)

        if self.world_rank == 0:
            logger.info("agent number of trainable parameters: {}".format(
                sum(param.numel() for param in self.agent.parameters()
                    if param.requires_grad)))
            if ppo_cfg.use_belief_predictor:
                logger.info(
                    "belief predictor number of trainable parameters: {}".
                    format(
                        sum(param.numel()
                            for param in self.belief_predictor.parameters()
                            if param.requires_grad)))
            logger.info(f"config: {self.config}")

        observations = self.envs.reset()
        batch = batch_obs(observations, device=self.device)

        obs_space = self.envs.observation_spaces[0]
        if ppo_cfg.use_external_memory:
            memory_dim = self.actor_critic.net.memory_dim
        else:
            memory_dim = None

        rollouts = RolloutStorage(
            ppo_cfg.num_steps,
            self.envs.num_envs,
            obs_space,
            self.action_space,
            ppo_cfg.hidden_size,
            ppo_cfg.use_external_memory,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size + ppo_cfg.num_steps,
            ppo_cfg.SCENE_MEMORY_TRANSFORMER.memory_size,
            memory_dim,
            num_recurrent_layers=self.actor_critic.net.num_recurrent_layers,
        )
        rollouts.to(self.device)

        if self.config.RL.PPO.use_belief_predictor:
            self.belief_predictor.update(batch, None)

        for sensor in rollouts.observations:
            rollouts.observations[sensor][0].copy_(batch[sensor])

        # batch and observations may contain shared PyTorch CUDA
        # tensors.  We must explicitly clear them here otherwise
        # they will be kept in memory for the entire duration of training!
        batch = None
        observations = None

        current_episode_reward = torch.zeros(self.envs.num_envs,
                                             1,
                                             device=self.device)
        running_episode_stats = dict(
            count=torch.zeros(self.envs.num_envs, 1, device=self.device),
            reward=torch.zeros(self.envs.num_envs, 1, device=self.device),
        )
        window_episode_stats = defaultdict(
            lambda: deque(maxlen=ppo_cfg.reward_window_size))

        t_start = time.time()
        env_time = 0
        pth_time = 0
        count_steps = 0
        count_checkpoints = 0
        start_update = 0
        prev_time = 0

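        # LambdaLR rescales the optimizer's base learning rate by
        # linear_decay(update, NUM_UPDATES), assumed here to fall linearly
        # from 1 toward 0 over training; it is stepped once per update below
        # when use_linear_lr_decay is set.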
        lr_scheduler = LambdaLR(
            optimizer=self.agent.optimizer,
            lr_lambda=lambda x: linear_decay(x, self.config.NUM_UPDATES),
        )

        # Try to resume at previous checkpoint (independent of interrupted states)
        count_steps_start, count_checkpoints, start_update = (
            self.try_to_resume_checkpoint())
        count_steps = count_steps_start

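        # If a previous run was preempted, restore agent, optimizer and
        # scheduler state along with the bookkeeping counters saved on requeue.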
        interrupted_state = load_interrupted_state()
        if interrupted_state is not None:
            self.agent.load_state_dict(interrupted_state["state_dict"])
            if self.config.RL.PPO.use_belief_predictor:
                self.belief_predictor.load_state_dict(
                    interrupted_state["belief_predictor"])
            self.agent.optimizer.load_state_dict(
                interrupted_state["optim_state"])
            lr_scheduler.load_state_dict(interrupted_state["lr_sched_state"])

            requeue_stats = interrupted_state["requeue_stats"]
            env_time = requeue_stats["env_time"]
            pth_time = requeue_stats["pth_time"]
            count_steps = requeue_stats["count_steps"]
            count_checkpoints = requeue_stats["count_checkpoints"]
            start_update = requeue_stats["start_update"]
            prev_time = requeue_stats["prev_time"]

        with (TensorboardWriter(self.config.TENSORBOARD_DIR,
                                flush_secs=self.flush_secs)
              if self.world_rank == 0 else contextlib.suppress()) as writer:
            for update in range(start_update, self.config.NUM_UPDATES):
                if ppo_cfg.use_linear_lr_decay:
                    lr_scheduler.step()

                if ppo_cfg.use_linear_clip_decay:
                    self.agent.clip_param = ppo_cfg.clip_param * linear_decay(
                        update, self.config.NUM_UPDATES)

                if EXIT.is_set():
                    self.envs.close()

                    if REQUEUE.is_set() and self.world_rank == 0:
                        requeue_stats = dict(
                            env_time=env_time,
                            pth_time=pth_time,
                            count_steps=count_steps,
                            count_checkpoints=count_checkpoints,
                            start_update=update,
                            prev_time=(time.time() - t_start) + prev_time,
                        )
                        state_dict = dict(
                            state_dict=self.agent.state_dict(),
                            optim_state=self.agent.optimizer.state_dict(),
                            lr_sched_state=lr_scheduler.state_dict(),
                            config=self.config,
                            requeue_stats=requeue_stats,
                        )
                        if self.config.RL.PPO.use_belief_predictor:
                            state_dict['belief_predictor'] = (
                                self.belief_predictor.state_dict())
                        save_interrupted_state(state_dict)

                    requeue_job()
                    return

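                # Collect up to num_steps of experience per environment with
                # the policy (and belief predictor, if any) in eval mode;
                # straggler workers may cut the rollout short (see below).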
                count_steps_delta = 0
                self.agent.eval()
                if self.config.RL.PPO.use_belief_predictor:
                    self.belief_predictor.eval()
                for step in range(ppo_cfg.num_steps):

                    (
                        delta_pth_time,
                        delta_env_time,
                        delta_steps,
                    ) = self._collect_rollout_step(rollouts,
                                                   current_episode_reward,
                                                   running_episode_stats)
                    pth_time += delta_pth_time
                    env_time += delta_env_time
                    count_steps_delta += delta_steps

                    # This is where the preemption of workers happens.  If a
                    # worker detects it will be a straggler, it preempts itself!
                    if (step >=
                            ppo_cfg.num_steps * self.SHORT_ROLLOUT_THRESHOLD
                        ) and int(num_rollouts_done_store.get("num_done")) > (
                            self.config.RL.DDPPO.sync_frac * self.world_size):
                        break

                num_rollouts_done_store.add("num_done", 1)

                self.agent.train()
                if self.config.RL.PPO.use_belief_predictor:
                    self.belief_predictor.train()
                    self.belief_predictor.set_eval_encoders()
                if self._static_smt_encoder:
                    self.actor_critic.net.set_eval_encoders()

                if ppo_cfg.use_belief_predictor and ppo_cfg.BELIEF_PREDICTOR.online_training:
                    location_predictor_loss, prediction_accuracy = self.train_belief_predictor(
                        rollouts)
                else:
                    location_predictor_loss = 0
                    prediction_accuracy = 0
                (
                    delta_pth_time,
                    value_loss,
                    action_loss,
                    dist_entropy,
                ) = self._update_agent(ppo_cfg, rollouts)
                pth_time += delta_pth_time

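                # Aggregate episode statistics and losses across all workers
                # so that rank 0 can log globally averaged values.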
                stats_ordering = list(sorted(running_episode_stats.keys()))
                stats = torch.stack(
                    [running_episode_stats[k] for k in stats_ordering], 0)
                distrib.all_reduce(stats)

                for i, k in enumerate(stats_ordering):
                    window_episode_stats[k].append(stats[i].clone())

                stats = torch.tensor(
                    [
                        value_loss, action_loss, dist_entropy,
                        location_predictor_loss, prediction_accuracy,
                        count_steps_delta
                    ],
                    device=self.device,
                )
                distrib.all_reduce(stats)
                count_steps += stats[5].item()

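                # Only rank 0 resets the straggler counter and writes
                # TensorBoard scalars, logs, and checkpoints.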
                if self.world_rank == 0:
                    num_rollouts_done_store.set("num_done", "0")

                    losses = [
                        stats[0].item() / self.world_size,
                        stats[1].item() / self.world_size,
                        stats[2].item() / self.world_size,
                        stats[3].item() / self.world_size,
                        stats[4].item() / self.world_size,
                    ]
                    deltas = {
                        k: ((v[-1] - v[0]).sum().item()
                            if len(v) > 1 else v[0].sum().item())
                        for k, v in window_episode_stats.items()
                    }
                    deltas["count"] = max(deltas["count"], 1.0)

                    writer.add_scalar("Metrics/reward",
                                      deltas["reward"] / deltas["count"],
                                      count_steps)

                    # Check to see if there are any metrics
                    # that haven't been logged yet
                    metrics = {
                        k: v / deltas["count"]
                        for k, v in deltas.items()
                        if k not in {"reward", "count"}
                    }
                    if len(metrics) > 0:
                        for metric, value in metrics.items():
                            writer.add_scalar(f"Metrics/{metric}", value,
                                              count_steps)

                    writer.add_scalar("Policy/value_loss", losses[0],
                                      count_steps)
                    writer.add_scalar("Policy/policy_loss", losses[1],
                                      count_steps)
                    writer.add_scalar("Policy/entropy_loss", losses[2],
                                      count_steps)
                    writer.add_scalar("Policy/predictor_loss", losses[3],
                                      count_steps)
                    writer.add_scalar("Policy/predictor_accuracy", losses[4],
                                      count_steps)
                    writer.add_scalar('Policy/learning_rate',
                                      lr_scheduler.get_lr()[0], count_steps)

                    # log stats
                    if update > 0 and update % self.config.LOG_INTERVAL == 0:
                        logger.info("update: {}\tfps: {:.3f}\t".format(
                            update,
                            (count_steps - count_steps_start) /
                            ((time.time() - t_start) + prev_time),
                        ))

                        logger.info(
                            "update: {}\tenv-time: {:.3f}s\tpth-time: {:.3f}s\t"
                            "frames: {}".format(update, env_time, pth_time,
                                                count_steps))
                        logger.info("Average window size: {}  {}".format(
                            len(window_episode_stats["count"]),
                            "  ".join(
                                "{}: {:.3f}".format(k, v / deltas["count"])
                                for k, v in deltas.items() if k != "count"),
                        ))

                    # checkpoint model
                    if update % self.config.CHECKPOINT_INTERVAL == 0:
                        self.save_checkpoint(
                            f"ckpt.{count_checkpoints}.pth",
                            dict(step=count_steps),
                        )
                        count_checkpoints += 1

            self.envs.close()
Example #30
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

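    # cudnn.benchmark auto-tunes convolution algorithms for fixed input sizes;
    # note it stays enabled even when a seed (and cudnn.deterministic) was set
    # above, which can limit reproducibility.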
    cudnn.benchmark = True

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root,
        split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

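    # ForeverDataIterator cycles its DataLoader indefinitely, so the length of
    # an epoch is presumably governed by args.iters_per_epoch inside train().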
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
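    # Polynomial ("poly") decay: the multiplicative factor shrinks as
    # (1 - t / (epochs * iters_per_epoch)) ** lr_power.  The segmentation
    # model's lambda additionally folds in args.lr, so the absolute rate
    # depends on how model.get_parameters() sets the per-group base lr.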
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix (str): prefix for the saved image file names
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"), (decode(label), "label"), (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix,
                                                                 name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean IoU over the subset of classes used for evaluation
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # save a checkpoint each epoch and remember the best mean IoU
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()