def train(model, config_train, config_valid):
    """Training-loop skeleton: builds the COCO loader, Adam optimizer and
    multi-step LR schedule, wraps the model in DataParallel, then iterates
    over batches.

    NOTE(review): the visible body ends right after the forward pass — the
    loss computation and backward step are still TODO — and `config_valid`
    is never used here.
    """
    max_num_epochs = config_train['max_num_epochs']
    train_set = CocoDataset(cfg=config_train, is_train=True)
    train_loader = DataLoader(train_set,
                              batch_size=config_train['batch_size'],
                              shuffle=True,
                              num_workers=config_train['num_workers'])
    optimizer = opt.Adam(model.parameters(),
                         lr=config_train['base_lr'],
                         weight_decay=5e-4)
    num_iter = 0  # NOTE(review): never incremented in the visible code
    current_epoch = 0
    # Step-wise LR decay at the given epoch milestones.
    scheduler = opt.lr_scheduler.MultiStepLR(optimizer,
                                             milestones=[100, 200, 260],
                                             gamma=0.333)
    # NOTE(review): `device` is not defined in this scope — presumably a
    # module-level global; confirm against the full file.
    model = DataParallel(model).to(device)
    model.train()
    for epoch in range(current_epoch, max_num_epochs):
        # NOTE(review): the `epoch=` kwarg to scheduler.step() is deprecated
        # in newer PyTorch releases.
        scheduler.step(epoch=epoch)
        batch_per_iter_idx = 0  # gradient-accumulation counter
        for batched_samples in train_loader:
            if batch_per_iter_idx == 0:
                optimizer.zero_grad()
            images = batched_samples['image'].cuda()
            keypoint_maps = batched_samples['keypoint_map'].cuda()
            depth_maps = batched_samples['depth_map'].cuda()
            offset_maps = batched_samples['offset_map'].cuda()
            # TODO loss
            predictions = model(images)
示例#2
0
def main():
    """Train ResNet-18 with an ArcFace margin head on CIFAR-10.

    Builds the data loaders, wraps both the backbone and the metric head
    in DataParallel, and runs a short SGD training loop with step LR decay.

    NOTE(review): `criterion` is not defined in this function — presumably
    a module-level loss (e.g. CrossEntropyLoss); confirm against the full
    file.
    """
    args = parse_args()

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset
    train_loader, test_loader, classes = cifar10_data.load_dataset(
        args.dataset_dir, img_show=True)

    model = resnet.resnet18()
    print(model)

    num_classes = 10
    easy_margin = False
    # ArcFace head: maps 512-d embeddings to class logits with an angular
    # margin of m=0.5 and scale s=30.
    metric_fc = metrics.ArcMarginProduct(512,
                                         num_classes,
                                         s=30,
                                         m=0.5,
                                         easy_margin=easy_margin)

    model.to(device)
    model = DataParallel(model)
    metric_fc.to(device)
    metric_fc = DataParallel(metric_fc)

    lr = 1e-1  # initial learning rate
    lr_step = 10
    weight_decay = 5e-4
    # Backbone and margin head are optimized jointly.
    optimizer = torch.optim.SGD([{
        'params': model.parameters()
    }, {
        'params': metric_fc.parameters()
    }],
                                lr=lr,
                                weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=lr_step, gamma=0.1)

    max_epoch = 2
    for i in range(max_epoch):
        model.train()
        for ii, (imgs, labels) in enumerate(train_loader):
            # Set batch data.
            imgs, labels = imgs.to(device), labels.to(device).long()
            feature = model(imgs)
            output = metric_fc(feature, labels)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(loss)

        # BUGFIX: since PyTorch 1.1 the LR scheduler must be stepped AFTER
        # the optimizer updates for the epoch; stepping it first (as the
        # original did) skips the initial learning-rate value.
        scheduler.step()
示例#3
0
    def test_resnet_baseline(self):
        """Benchmark forward+backward time of the baseline ResNet on GPU.

        Runs `total_iters` timed outer iterations (the early ones serve as
        warmup), each performing `iterations` train steps, and reports GPU
        time (via CUDA events) alongside wall-clock CPU time.
        """
        N = 100
        total_iters = 20  # (warmup + benchmark)
        iterations = 4

        # NOTE(review): target has N // 5 entries while the input batch is
        # N — confirm the model/criterion reduce the batch accordingly.
        target = Variable(torch.randn(N //
                                      5).fill_(1)).type("torch.LongTensor")
        x = Variable(torch.randn(N, 3, 224, 224).fill_(1.0),
                     requires_grad=True)
        model = resnet_baseline.load_resnet()
        model = DataParallel(model)

        # switch the model to train mode
        model.train()

        # convert the model and input to cuda
        model = model.cuda()
        input_var = x.cuda()
        target_var = target.cuda()

        # declare the optimizer and criterion
        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.SGD(model.parameters(),
                                    0.01,
                                    momentum=0.9,
                                    weight_decay=1e-4)
        # NOTE(review): zero_grad() is only called once, so gradients
        # accumulate across all steps — tolerable for a pure timing
        # benchmark, but confirm that is intended.
        optimizer.zero_grad()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        with cudnn.flags(enabled=True, benchmark=True):
            for i in range(total_iters):
                start.record()
                start_cpu = time.time()
                for j in range(iterations):
                    output = model(input_var)
                    loss = criterion(output, target_var)
                    loss.backward()
                    optimizer.step()

                end_cpu = time.time()
                end.record()
                # Wait for all queued GPU work before reading the timers.
                torch.cuda.synchronize()
                gpu_msec = start.elapsed_time(end)
                # BUGFIX: `file=sys.stderr` was previously passed to
                # str.format() — a TypeError (format() takes no `file`
                # kwarg). It belongs to print().
                print(
                    "Baseline resnet ({:2d}): ({:8.3f} usecs gpu) ({:8.3f} usecs cpu)"
                    .format(i, gpu_msec * 1000,
                            (end_cpu - start_cpu) * 1000000),
                    file=sys.stderr)
示例#4
0
def train(model, train_set, val_set, batch_size, num_epochs, lr, criterion, save_dir: str):
    """Train `model` on `train_set`, building the effective batch by
    gradient accumulation over `batch_size` single-sample loader batches.

    BUGFIXES vs. the original:
    - `//` C-style comments were Python syntax errors; replaced with `#`.
    - `DataLoader` has no `tranforms` keyword (TypeError at runtime);
      transforms belong on the Dataset, so the argument was dropped.
    - the unpacked batch no longer shadows the `lr` learning-rate
      parameter.
    """
    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = DataParallel(model)
        else:
            model = model.cuda(0)

    # batch_size=1 here on purpose: the effective batch is accumulated in
    # the loop below.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=True)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(num_epochs):
        for batch_idx, batch in enumerate(train_loader):
            lr_img, hr_img = batch
            # TODO: forward pass, loss, backward

            if (batch_idx + 1) % batch_size == 0:
                # TODO: optimizer step on the accumulated gradients, compute
                # SSIM on val_loader, and checkpoint to save_dir if SSIM
                # improves.
                pass
class PoemImageEmbedTrainer():
    """Trainer for the joint poem/image embedding model.

    Builds the train/test datasets and loaders, wraps the model in
    DataParallel, and runs per-epoch training with periodic logging and
    checkpointing.
    """
    def __init__(self, train_data, test_data, sentiment_model, batchsize, load_model, device):
        self.device = device
        self.train_data = train_data
        self.test_data = test_data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Standard augmentation pipeline for training images.
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

        # Deterministic resize for evaluation images.
        self.test_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

        img_dir = 'data/image'
        self.train_set = PoemImageEmbedDataset(self.train_data, img_dir,
                                               tokenizer=self.tokenizer, max_seq_len=100,
                                               transform=self.train_transform)
        self.train_loader = DataLoader(self.train_set, batch_size=batchsize, shuffle=True, num_workers=4)

        self.test_set = PoemImageEmbedDataset(self.test_data, img_dir,
                                              tokenizer=self.tokenizer, max_seq_len=100,
                                              transform=self.test_transform)
        self.test_loader = DataLoader(self.test_set, batch_size=batchsize, num_workers=4)

        self.model = PoemImageEmbedModel(device)

        self.model = DataParallel(self.model)
        # Initialize the image embedder's sentiment-feature extractor from a
        # pre-trained sentiment model checkpoint.
        load_dataparallel(self.model.module.img_embedder.sentiment_feature, sentiment_model)
        if load_model:
            logger.info('load model from '+ load_model)
            self.model.load_state_dict(torch.load(load_model))
        self.model.to(device)
        # Only the two linear projection heads are given to the optimizer;
        # the rest of the network is not updated by it.
        self.optimizer = optim.Adam(list(self.model.module.poem_embedder.linear.parameters()) + \
                                    list(self.model.module.img_embedder.linear.parameters()), lr=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[2, 4, 6], gamma=0.33)

    def train_epoch(self, epoch, log_interval, save_interval, ckpt_file):
        """Run one training epoch; log every `log_interval` batches and
        checkpoint every `save_interval` batches."""
        self.model.train()
        running_ls = 0  # loss accumulated since the last log line
        acc_ls = 0      # loss accumulated over the whole epoch
        start = time.time()
        num_batches = len(self.train_loader)
        for i, batch in enumerate(self.train_loader):
            img1, ids1, mask1, img2, ids2, mask2 = [t.to(self.device) for t in batch]
            self.model.zero_grad()
            loss = self.model(img1, ids1, mask1, img2, ids2, mask2)
            # Non-scalar backward: presumably the model returns a per-replica
            # loss vector under DataParallel — TODO confirm. The mean is used
            # for logging.
            loss.backward(torch.ones_like(loss))
            running_ls += loss.mean().item()
            acc_ls += loss.mean().item()
            self.optimizer.step()

            if (i + 1) % log_interval == 0:
                elapsed_time = time.time() - start
                iters_per_sec = (i + 1) / elapsed_time
                remaining = (num_batches - i - 1) / iters_per_sec
                remaining_time = time.strftime("%H:%M:%S", time.gmtime(remaining))

                print('[{:>2}, {:>4}/{}] running loss:{:.4} acc loss:{:.4} {:.3}iters/s {} left'.format(
                    epoch, (i + 1), num_batches, running_ls / log_interval, acc_ls /(i+1),
                    iters_per_sec, remaining_time))
                running_ls = 0

            if (i + 1) % save_interval == 0:
                self.save_model(ckpt_file)

    def save_model(self, file):
        """Persist the (DataParallel-wrapped) model's state dict to `file`."""
        torch.save(self.model.state_dict(), file)
示例#6
0
def main(verbose: int = 1,
         print_freq: int = 100,
         restore: Union[bool, str] = True,
         val_freq: int = 1,
         run_id: str = "model",
         dset_name: str = "memento_frames",
         model_name: str = "frames",
         freeze_until_it: int = 1000,
         additional_metrics: Mapping[str, Callable] = {'rc': rc},
         debug_n: Optional[int] = None,
         batch_size: int = cfg.BATCH_SIZE,
         require_strict_model_load: bool = False,
         restore_optimizer=True,
         optim_string='adam',
         lr=0.01) -> None:
    """End-to-end training entry point for the memorability model.

    Sets up checkpoint/log directories, builds the model (DataParallel),
    optimizer and LR schedule, optionally restores a checkpoint, then runs
    the train/validate loop, checkpointing each epoch and on Ctrl-C.

    NOTE(review): `additional_metrics` uses a mutable dict as a default
    argument (shared across calls) — harmless as long as it is never
    mutated, but worth confirming.
    NOTE(review): `print_freq` is accepted but never used in this body.
    """

    print("TRAINING MODEL {} ON DATASET {}".format(model_name, dset_name))

    ckpt_savedir = os.path.join(cfg.DATA_SAVEDIR, run_id, cfg.CKPT_DIR)
    print("Saving ckpts to {}".format(ckpt_savedir))
    logs_savepath = os.path.join(cfg.DATA_SAVEDIR, run_id, cfg.LOGDIR)
    print("Saving logs to {}".format(logs_savepath))
    utils.makedirs([ckpt_savedir, logs_savepath])
    last_ckpt_path = os.path.join(ckpt_savedir, "last_model.pth")

    device = utils.set_device()

    print('DEVICE', device)

    # model
    model = get_model(model_name, device)
    model = DataParallel(model)

    # must call this before constructing the optimizer:
    # https://pytorch.org/docs/stable/optim.html
    model.to(device)

    # set up training
    # TODO better one?

    if optim_string == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optim_string == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=lr,
                                    momentum=0.9,
                                    weight_decay=0.0001)
    else:
        raise RuntimeError(
            "Unrecognized optimizer string {}".format(optim_string))

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=5,
                                                   gamma=0.1)
    # Multi-task loss: memorability MSE plus a captioning loss; the total
    # loss is the sum of both terms.
    losses = {
        'mem_mse':
        MemMSELoss(device=device, weights=np.load("memento_weights.npy")),
        'captions':
        CaptionsLoss(device=device,
                     class_weights=cap_utils.get_vocab_weights())
    }

    initial_epoch = 0
    iteration = 0
    unfrozen = False  # encoder starts frozen; unfrozen after freeze_until_it

    if restore:
        # `restore` may be a checkpoint path (str) or True for the default
        # last-model path.
        ckpt_path = restore if isinstance(restore, str) else last_ckpt_path

        if os.path.exists(ckpt_path):

            print("Restoring weights from {}".format(ckpt_path))

            ckpt = torch.load(ckpt_path)
            utils.try_load_state_dict(model, ckpt['model_state_dict'],
                                      require_strict_model_load)

            if restore_optimizer:
                utils.try_load_optim_state(optimizer,
                                           ckpt['optimizer_state_dict'],
                                           require_strict_model_load)
            initial_epoch = ckpt['epoch']
            iteration = ckpt['it']
    else:
        ckpt_path = last_ckpt_path

    # dataset
    train_ds, val_ds, test_ds = get_dataset(dset_name)
    # NOTE(review): `val_ds` is never used below — validation runs on
    # test_dl; confirm that is intended.
    assert val_ds or test_ds

    if debug_n is not None:
        train_ds = Subset(train_ds, range(debug_n))
        test_ds = Subset(test_ds, range(debug_n))

    train_dl = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=cfg.NUM_WORKERS)
    test_dl = DataLoader(test_ds,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=cfg.NUM_WORKERS)

    # training loop
    start = time.time()

    try:
        for epoch in range(initial_epoch, cfg.NUM_EPOCHS):
            # NOTE(review): a new SummaryWriter is created every epoch and
            # never closed — consider hoisting it out of the loop.
            logger = SummaryWriter(logs_savepath)

            # effectively puts the model in train mode.
            # Opposite of model.eval()
            model.train()

            print("Epoch {}".format(epoch))

            # NOTE(review): `total` is a float here; tqdm accepts it but an
            # integer ceil would be cleaner.
            for i, (x, y_) in tqdm(enumerate(train_dl),
                                   total=len(train_ds) / batch_size):

                y: ModelOutput[MemModelFields] = ModelOutput(y_)
                iteration += 1

                # One-shot unfreeze of the encoder once enough iterations
                # have elapsed.
                if not unfrozen and iteration > freeze_until_it:
                    print("Unfreezing encoder")
                    unfrozen = True

                    for param in model.parameters():
                        param.requires_grad = True

                logger.add_scalar('DataTime', time.time() - start, iteration)

                x = x.to(device)
                y = y.to_device(device)

                out = ModelOutput(model(x, y.get_data()))
                loss_vals = {name: l(out, y) for name, l in losses.items()}
                loss = torch.stack(list(loss_vals.values()))

                if verbose:
                    print("stacked loss", loss)
                # Total loss = sum of the individual task losses.
                loss = loss.sum()

                # I think this zeros out previous gradients (in case people
                # want to accumulate gradients?)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # logging
                utils.log_loss(logger, loss, loss_vals, iteration)
                logger.add_scalar('ItTime', time.time() - start, iteration)
                start = time.time()

                # display metrics

            # do some validation

            if (epoch + 1) % val_freq == 0:
                print("Validating...")
                model.eval()  # puts model in validation mode
                val_iteration = iteration

                with torch.no_grad():

                    # Accumulate labels/predictions over the whole val pass
                    # so metrics can be computed on the full set.
                    labels: Optional[ModelOutput[MemModelFields]] = None
                    preds: Optional[ModelOutput[MemModelFields]] = None
                    val_losses = []

                    for i, (x, y_) in tqdm(enumerate(test_dl),
                                           total=len(test_ds) / batch_size):
                        val_iteration += 1

                        y = ModelOutput(y_)
                        y_numpy = y.to_numpy()

                        labels = y_numpy if labels is None else labels.merge(
                            y_numpy)

                        x = x.to(device)
                        y = y.to_device(device)

                        out = ModelOutput(model(x, y.get_data()))
                        out_numpy = out.to_device('cpu').to_numpy()
                        preds = out_numpy if preds is None else preds.merge(
                            out_numpy)

                        loss_vals = {
                            name: l(out, y)
                            for name, l in losses.items()
                        }
                        loss = torch.stack(list(loss_vals.values())).sum()
                        utils.log_loss(logger,
                                       loss,
                                       loss_vals,
                                       val_iteration,
                                       phase='val')

                        val_losses.append(loss)

                    print("Calculating validation metric...")
                    metrics = {
                        fname: f(labels, preds, losses)
                        for fname, f in additional_metrics.items()
                    }
                    print("Validation metrics", metrics)

                    # Only scalar metrics go to TensorBoard.
                    for k, v in metrics.items():
                        if isinstance(v, numbers.Number):
                            logger.add_scalar('Metric_{}'.format(k), v,
                                              iteration)

                    metrics['total_val_loss'] = sum(val_losses)

                    # Metric-stamped checkpoint for this validation pass.
                    ckpt_path = os.path.join(
                        ckpt_savedir, utils.get_ckpt_path(epoch, metrics))
                    save_ckpt(ckpt_path, model, epoch, iteration, optimizer,
                              dset_name, model_name, metrics)

            # end of epoch
            lr_scheduler.step()

            save_ckpt(last_ckpt_path, model, epoch, iteration, optimizer,
                      dset_name, model_name)

    except KeyboardInterrupt:
        # NOTE(review): if the interrupt fires before the first epoch
        # starts, `epoch` is unbound here (NameError).
        print('Got keyboard interrupt, saving model...')
        save_ckpt(last_ckpt_path, model, epoch, iteration, optimizer,
                  dset_name, model_name)
示例#7
0
def train(args):
    """Train the model described by `args`, validating every
    `args.iter_val` iterations and checkpointing on top-1 improvement.

    NOTE(review): `criterion(args, output, target)` and `get_lrs` are
    defined elsewhere in the project; `train_loader.num` is a custom
    attribute of the project's loader — confirm against the full file.
    """
    print('start training...')
    model, model_file = create_model(args)
    # Wrap in DataParallel on multi-GPU boxes, preserving the custom
    # `.name` attribute the wrapper would otherwise hide.
    if torch.cuda.device_count() > 1:
        model_name = model.name
        model = DataParallel(model)
        model.name = model_name
    model = model.cuda()

    if args.optim == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=0.0001)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=0.0001)

    # mode='max' because the plateau scheduler tracks top-1 accuracy.
    if args.lrs == 'plateau':
        lr_scheduler = ReduceLROnPlateau(optimizer,
                                         mode='max',
                                         factor=args.factor,
                                         patience=args.patience,
                                         min_lr=args.min_lr)
    else:
        lr_scheduler = CosineAnnealingLR(optimizer,
                                         args.t_max,
                                         eta_min=args.min_lr)
    #ExponentialLR(optimizer, 0.9, last_epoch=-1) #CosineAnnealingLR(optimizer, 15, 1e-7)

    _, val_loader = get_train_val_loaders(batch_size=args.batch_size,
                                          val_num=args.val_num)

    best_top1_acc = 0.

    # Header for the fixed-width progress table printed below.
    print(
        'epoch |    lr    |      %        |  loss  |  avg   |  loss  |  top1  | top10  |  best  | time |  save |'
    )

    # Optional pre-training validation pass to seed best_top1_acc.
    if not args.no_first_val:
        top10_acc, best_top1_acc, total_loss = validate(
            args, model, val_loader)
        print(
            'val   |          |               |        |        | {:.4f} | {:.4f} | {:.4f} | {:.4f} |      |       |'
            .format(total_loss, best_top1_acc, top10_acc, best_top1_acc))

    if args.val:
        return

    model.train()

    if args.lrs == 'plateau':
        lr_scheduler.step(best_top1_acc)
    else:
        lr_scheduler.step()
    train_iter = 0

    for epoch in range(args.start_epoch, args.epochs):
        # Loaders are rebuilt every epoch (fresh train/val split).
        train_loader, val_loader = get_train_val_loaders(
            batch_size=args.batch_size,
            dev_mode=args.dev_mode,
            val_num=args.val_num)

        train_loss = 0

        current_lr = get_lrs(
            optimizer)  #optimizer.state_dict()['param_groups'][2]['lr']
        bg = time.time()
        for batch_idx, data in enumerate(train_loader):
            train_iter += 1
            img, target = data
            img, target = img.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(img)

            loss = criterion(args, output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # '\r' + end='' keeps the progress on a single updating line.
            print('\r {:4d} | {:.6f} | {:06d}/{} | {:.4f} | {:.4f} |'.format(
                epoch, float(current_lr[0]), args.batch_size * (batch_idx + 1),
                train_loader.num, loss.item(), train_loss / (batch_idx + 1)),
                  end='')

            # Periodic validation + checkpoint-on-improvement.
            if train_iter > 0 and train_iter % args.iter_val == 0:
                top10_acc, top1_acc, total_loss = validate(
                    args, model, val_loader)

                _save_ckp = ''
                if args.always_save or top1_acc > best_top1_acc:
                    best_top1_acc = top1_acc
                    # Unwrap DataParallel so the checkpoint loads without it.
                    if isinstance(model, DataParallel):
                        torch.save(model.module.state_dict(), model_file)
                    else:
                        torch.save(model.state_dict(), model_file)
                    _save_ckp = '*'
                print(' {:.4f} | {:.4f} | {:.4f} | {:.4f} | {:.2f} |  {:4s} |'.
                      format(total_loss, top1_acc, top10_acc, best_top1_acc,
                             (time.time() - bg) / 60, _save_ckp))

                # validate() put the model in eval mode; restore training.
                model.train()

                if args.lrs == 'plateau':
                    lr_scheduler.step(top1_acc)
                else:
                    lr_scheduler.step()
                current_lr = get_lrs(optimizer)
示例#8
0
class BaseTrainer(object):
    def __init__(self, args):
        """Set up the tokenizer, BERT backbone, QA output head and the
        pre-processed training features.

        NOTE(review): the QA head is moved to 'cuda' unconditionally —
        confirm a GPU is guaranteed to be available here.
        """
        self.args = args
        self.set_random_seed(random_seed=args.random_seed)
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # Linear layer producing start/end span logits (2 outputs).
        self.qa_outputs = nn.Linear(self.args.hidden_size, 2).to('cuda')
        # init weight
        self.tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
        if args.debug:
            print("Debugging mode on.")
        self.features_lst = self.get_features(self.args.train_folder,
                                              self.args.debug)

    def estimate_fisher(self, data_loader, sample_size, batch_size=32):
        """Estimate the diagonal of the Fisher information matrix from
        log-likelihood gradients over `data_loader`.

        NOTE(review): `sample_size` and `batch_size` are accepted but never
        used in this body.
        """
        # sample loglikelihoods from the dataset.

        loglikelihoods = []

        for i, batch in enumerate(data_loader, start=1):
            input_ids, input_mask, seg_ids, start_positions, end_positions, _ = batch
            # Actual (non-padded) lengths; used to truncate the batch.
            seq_len = torch.sum(torch.sign(input_ids), 1).detach()
            max_len = torch.max(seq_len).detach()

            input_ids = input_ids[:, :max_len].cuda(self.args.gpu,
                                                    non_blocking=True)
            input_mask = input_mask[:, :max_len].cuda(self.args.gpu,
                                                      non_blocking=True)
            seg_ids = seg_ids[:, :max_len].cuda(self.args.gpu,
                                                non_blocking=True)
            start_positions = start_positions.cuda(self.args.gpu,
                                                   non_blocking=True)
            end_positions = end_positions.cuda(self.args.gpu,
                                               non_blocking=True)

            try:
                # NOTE(review): the model is re-wrapped in DataParallel on
                # every batch — likely unintended overhead.
                model = self.bert.to('cuda')
                model = nn.DataParallel(model)
                x = model(input_ids,
                          attention_mask=input_mask,
                          token_type_ids=seg_ids)[0]
                x = torch.stack(x)

                logits = self.qa_outputs(x)
            except RuntimeError as exception:
                # NOTE(review): on OOM the loop continues with `logits`
                # unbound (or stale from the previous batch) — the next line
                # can raise NameError or reuse old logits; confirm intent.
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

            log_prob = F.log_softmax(logits, dim=0)
            loglikelihoods.append(log_prob)
            gc.collect()
        print(loglikelihoods)
        # estimate the fisher information of the parameters.
        # NOTE(review): torch.tensor() over a list of multi-element tensors
        # detaches from the graph (and may error); the autograd.grad calls
        # below therefore cannot see the original model graph — verify this
        # code path actually runs.
        loglikelihoods = torch.tensor(loglikelihoods)
        loglikelihoods = Variable(loglikelihoods, requires_grad=True)
        # NOTE(review): `self.model` is only created in make_model_env(), and
        # `self.named_parameters` does not exist on this plain class —
        # presumably this code expects an nn.Module mixin; confirm.
        loglikelihood_grads = zip(*[
            autograd.grad(l,
                          self.model.parameters(),
                          retain_graph=(i < len(loglikelihoods)))
            for i, l in enumerate(loglikelihoods, 1)
        ])
        loglikelihood_grads = [torch.stack(gs) for gs in loglikelihood_grads]
        fisher_diagonals = [(g**2).mean(0) for g in loglikelihood_grads]
        param_names = [
            n.replace('.', '__') for n, p in self.named_parameters()
        ]
        return {n: f.detach() for n, f in zip(param_names, fisher_diagonals)}

    def consolidate(self, fisher):
        for n, p in self.named_parameters():
            n = n.replace('.', '__')
            self.register_buffer('{}_mean'.format(n), p.data.clone())
            self.register_buffer('{}_fisher'.format(n), fisher[n].data.clone())

    def ewc_loss(self, cuda=False):
        try:
            losses = []
            for n, p in self.named_parameters():
                # retrieve the consolidated mean and fisher information.
                n = n.replace('.', '__')
                mean = getattr(self, '{}_mean'.format(n))
                fisher = getattr(self, '{}_fisher'.format(n))
                # wrap mean and fisher in variables.
                mean = Variable(mean)
                fisher = Variable(fisher)
                # calculate a ewc loss. (assumes the parameter's prior as
                # gaussian distribution with the estimated mean and the
                # estimated cramer-rao lower bound variance, which is
                # equivalent to the inverse of fisher information)
                losses.append((fisher * (p - mean)**2).sum())
            return 20 * sum(losses)
        except AttributeError:
            # ewc loss is 0 if there's no consolidated parameters.
            return (Variable(torch.zeros(1)).cuda() if cuda else Variable(
                torch.zeros(1)))

    def _is_on_cuda(self):
        return next(self.parameters()).is_cuda

    def make_model_env(self, gpu, ngpus_per_node):
        """Build the QA model and wrap it for (distributed) multi-GPU use.

        Selects this process's GPU, initializes the process group in
        distributed mode, loads optional pre-trained weights, builds the
        optimizer, and wraps the model in DistributedDataParallel or
        DataParallel.
        """
        if self.args.distributed:
            self.args.gpu = self.args.devices[gpu]
        else:
            self.args.gpu = 0

        if self.args.use_cuda and self.args.distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            self.args.rank = self.args.rank * ngpus_per_node + gpu
            dist.init_process_group(backend=self.args.dist_backend,
                                    init_method=self.args.dist_url,
                                    world_size=self.args.world_size,
                                    rank=self.args.rank)

        # Load baseline model
        self.model = BertForQuestionAnswering.from_pretrained(
            self.args.bert_model)

        if self.args.load_model is not None:
            print("Loading model from ", self.args.load_model)
            # map_location keeps CPU-only loading possible regardless of
            # where the checkpoint was saved.
            self.model.load_state_dict(
                torch.load(self.args.load_model,
                           map_location=lambda storage, loc: storage))

        # Schedule length = steps per epoch (of the largest dataset) times
        # epochs times number of datasets.
        max_len = max([len(f) for f in self.features_lst])
        num_train_optimization_steps = math.ceil(
            max_len / self.args.batch_size) * self.args.epochs * len(
                self.features_lst)

        if self.args.freeze_bert:
            for param in self.model.bert.parameters():
                param.requires_grad = False

        # NOTE(review): optimizer is built BEFORE the model is moved to GPU —
        # the PyTorch docs recommend .cuda() first; confirm get_opt copes.
        self.optimizer = get_opt(list(self.model.named_parameters()),
                                 num_train_optimization_steps, self.args)

        if self.args.use_cuda:
            if self.args.distributed:
                torch.cuda.set_device(self.args.gpu)
                self.model.cuda(self.args.gpu)
                # Per-process share of the global batch and worker count.
                self.args.batch_size = int(self.args.batch_size /
                                           ngpus_per_node)
                self.args.workers = int(
                    (self.args.workers + ngpus_per_node - 1) / ngpus_per_node)
                self.model = DistributedDataParallel(
                    self.model,
                    device_ids=[self.args.gpu],
                    find_unused_parameters=True)
            else:
                self.model.cuda()
                self.model = DataParallel(self.model,
                                          device_ids=self.args.devices)

        cudnn.benchmark = True

    def make_run_env(self):
        if self.args.distributed:
            # distributing dev file evaluation task
            self.dev_files = []
            gpu_num = len(self.args.devices)
            files = os.listdir(self.args.dev_folder)
            for i in range(len(files)):
                if i % gpu_num == self.args.rank:
                    self.dev_files.append(files[i])

            print("GPU {}".format(self.args.gpu), self.dev_files)
        else:
            self.dev_files = os.listdir(self.args.dev_folder)
            print(self.dev_files)

    def get_features(self, train_folder, debug=False):
        """Load (or build and cache) SQuAD-style features for every .gz
        dataset file in `train_folder`.

        Returns a list with one feature list per dataset file. Features are
        cached as pickles in a folder keyed by BERT model and the
        skip-no-answer setting.
        """
        pickled_folder = self.args.pickled_folder + "_{}_{}".format(
            self.args.bert_model, str(self.args.skip_no_ans))

        features_lst = []

        files = [f for f in os.listdir(train_folder) if f.endswith(".gz")]
        print("Number of data set:{}".format(len(files)))
        for filename in files:
            data_name = filename.split(".")[0]
            # Check whether pkl file already exists
            pickle_file_name = '{}.pkl'.format(data_name)
            pickle_file_path = os.path.join(pickled_folder, pickle_file_name)
            if os.path.exists(pickle_file_path):
                with open(pickle_file_path, 'rb') as pkl_f:
                    print("Loading {} file as pkl...".format(data_name))
                    # NOTE(review): pickle.load on cached files — safe only
                    # because this cache is produced locally below.
                    features_lst.append(pickle.load(pkl_f))
            else:
                print("processing {} file".format(data_name))
                file_path = os.path.join(train_folder, filename)

                train_examples = read_squad_examples(file_path, debug=debug)

                train_features = convert_examples_to_features(
                    examples=train_examples,
                    tokenizer=self.tokenizer,
                    max_seq_length=self.args.max_seq_length,
                    max_query_length=self.args.max_query_length,
                    doc_stride=self.args.doc_stride,
                    is_training=True,
                    skip_no_ans=self.args.skip_no_ans)

                features_lst.append(train_features)

                # Save feature lst as pickle (For reuse & fast loading).
                # Only rank 0 writes, to avoid concurrent-write races in
                # distributed runs.
                if not debug and self.args.rank == 0:
                    with open(pickle_file_path, 'wb') as pkl_f:
                        print("Saving {} file from pkl file...".format(
                            data_name))
                        pickle.dump(train_features, pkl_f)

        return features_lst

    def get_data_loader(self, features_lst, args):
        """Build one training DataLoader over every feature set.

        Concatenates the tensorized features of all data sets in
        ``features_lst`` (a feature set's position in the list doubles as
        its class label) and wraps the result in a
        ``DistributedSampler`` loader (distributed mode) or a
        class-balanced ``WeightedRandomSampler`` loader otherwise.

        Args:
            features_lst: list of per-data-set feature lists.
            args: namespace providing ``distributed``, ``workers`` and
                ``batch_size`` (``self.args`` supplies ``num_classes``
                and ``random_seed``).

        Returns:
            A ``DataLoader`` over the combined training data.
        """
        all_input_ids = []
        all_input_mask = []
        all_segment_ids = []
        all_start_positions = []
        all_end_positions = []
        all_labels = []

        for i, train_features in enumerate(features_lst):
            all_input_ids.append(
                torch.tensor([f.input_ids for f in train_features],
                             dtype=torch.long))
            all_input_mask.append(
                torch.tensor([f.input_mask for f in train_features],
                             dtype=torch.long))
            all_segment_ids.append(
                torch.tensor([f.segment_ids for f in train_features],
                             dtype=torch.long))

            start_positions = torch.tensor(
                [f.start_position for f in train_features], dtype=torch.long)
            end_positions = torch.tensor(
                [f.end_position for f in train_features], dtype=torch.long)

            all_start_positions.append(start_positions)
            all_end_positions.append(end_positions)
            # Data-set index is reused as the class label.
            all_labels.append(i * torch.ones_like(start_positions))

        all_input_ids = torch.cat(all_input_ids, dim=0)
        all_input_mask = torch.cat(all_input_mask, dim=0)
        all_segment_ids = torch.cat(all_segment_ids, dim=0)
        all_start_positions = torch.cat(all_start_positions, dim=0)
        all_end_positions = torch.cat(all_end_positions, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions, all_labels)

        if args.distributed:
            train_sampler = DistributedSampler(train_data)
            data_loader = DataLoader(train_data,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=train_sampler,
                                     batch_size=args.batch_size)
        else:
            weights = make_weights_for_balanced_classes(
                all_labels.detach().cpu().numpy().tolist(),
                self.args.num_classes)
            weights = torch.DoubleTensor(weights)
            train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                weights, len(weights))
            # BUG FIX: the original passed
            # ``worker_init_fn=self.set_random_seed(seed)`` which *calls*
            # the seeding function in the parent process and hands the
            # DataLoader its return value (None), so workers were never
            # seeded.  Keep the parent-process seeding side effect, and
            # give workers a real callable with per-worker offsets.
            self.set_random_seed(self.args.random_seed)
            if self.args.random_seed is not None:
                base_seed = self.args.random_seed

                def _worker_init(worker_id):
                    # Distinct seed per worker so random streams differ.
                    self.set_random_seed(base_seed + worker_id)
            else:
                _worker_init = None
            data_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batch_size,
                shuffle=False,  # mutually exclusive with ``sampler``
                sampler=train_sampler,
                num_workers=args.workers,
                worker_init_fn=_worker_init,
                pin_memory=True,
                drop_last=True)
        return data_loader

    def get_iter(self, features_lst, args):
        """Like ``get_data_loader`` but also returns the sampler.

        The sampler is needed by distributed training so the caller can
        invoke ``sampler.set_epoch(epoch)`` each epoch.

        Args:
            features_lst: list of per-data-set feature lists.
            args: namespace providing ``distributed``, ``workers`` and
                ``batch_size`` (``self.args`` supplies ``num_classes``
                and ``random_seed``).

        Returns:
            tuple ``(data_loader, train_sampler)``.
        """
        all_input_ids = []
        all_input_mask = []
        all_segment_ids = []
        all_start_positions = []
        all_end_positions = []
        all_labels = []

        for i, train_features in enumerate(features_lst):
            all_input_ids.append(
                torch.tensor([f.input_ids for f in train_features],
                             dtype=torch.long))
            all_input_mask.append(
                torch.tensor([f.input_mask for f in train_features],
                             dtype=torch.long))
            all_segment_ids.append(
                torch.tensor([f.segment_ids for f in train_features],
                             dtype=torch.long))

            start_positions = torch.tensor(
                [f.start_position for f in train_features], dtype=torch.long)
            end_positions = torch.tensor(
                [f.end_position for f in train_features], dtype=torch.long)

            all_start_positions.append(start_positions)
            all_end_positions.append(end_positions)
            # Data-set index is reused as the class label.
            all_labels.append(i * torch.ones_like(start_positions))

        all_input_ids = torch.cat(all_input_ids, dim=0)
        all_input_mask = torch.cat(all_input_mask, dim=0)
        all_segment_ids = torch.cat(all_segment_ids, dim=0)
        all_start_positions = torch.cat(all_start_positions, dim=0)
        all_end_positions = torch.cat(all_end_positions, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions, all_labels)
        if args.distributed:
            train_sampler = DistributedSampler(train_data)
            data_loader = DataLoader(train_data,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=train_sampler,
                                     batch_size=args.batch_size)
        else:
            weights = make_weights_for_balanced_classes(
                all_labels.detach().cpu().numpy().tolist(),
                self.args.num_classes)
            weights = torch.DoubleTensor(weights)
            train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                weights, len(weights))
            # BUG FIX: same defect as get_data_loader — the original
            # *called* ``self.set_random_seed(...)`` and passed None as
            # ``worker_init_fn``, so workers were never seeded.  Keep the
            # parent-process seeding side effect and pass a real callable.
            self.set_random_seed(self.args.random_seed)
            if self.args.random_seed is not None:
                base_seed = self.args.random_seed

                def _worker_init(worker_id):
                    # Distinct seed per worker so random streams differ.
                    self.set_random_seed(base_seed + worker_id)
            else:
                _worker_init = None
            data_loader = torch.utils.data.DataLoader(
                train_data,
                batch_size=args.batch_size,
                shuffle=False,  # mutually exclusive with ``sampler``
                sampler=train_sampler,
                num_workers=args.workers,
                worker_init_fn=_worker_init,
                pin_memory=True,
                drop_last=True)

        return data_loader, train_sampler

    def save_model(self, epoch, loss):
        """Persist model weights and config, tagged with epoch and loss.

        Files are named ``<type>_<epoch>_<loss>.pt`` and
        ``<type>_config_<epoch>_<loss>.json`` under ``args.save_dir``,
        where ``<type>`` is ``adv`` or ``base`` depending on ``args.adv``.
        """
        loss = round(loss, 3)
        if self.args.adv:
            model_type = "adv"
        else:
            model_type = "base"

        weight_path = os.path.join(
            self.args.save_dir,
            "{}_{}_{:.3f}.pt".format(model_type, epoch, loss))
        config_path = os.path.join(
            self.args.save_dir,
            "{}_config_{}_{:.3f}.json".format(model_type, epoch, loss))

        # Unwrap DataParallel/DistributedDataParallel before saving so the
        # checkpoint keys have no "module." prefix.
        target = getattr(self.model, 'module', self.model)

        torch.save(target.state_dict(), weight_path)
        target.config.to_json_file(config_path)

    def train(self, consolidate=True, fisher_estimation_sample_size=1024):
        """Main QA training loop with optional EWC consolidation.

        Trains for ``args.epochs`` epochs over loaders built from
        ``self.features_lst``, adding an EWC penalty to the QA loss,
        evaluating periodically, and checkpointing once per epoch on
        rank 0.  When ``consolidate`` is True, Fisher information is
        estimated afterwards and consolidated into the network.

        Args:
            consolidate: estimate/consolidate Fisher information at the end.
            fisher_estimation_sample_size: sample count for the Fisher
                estimate.
        """
        step = 1
        avg_loss = 0
        global_step = 1
        # NOTE(review): ``step`` is never incremented, so the
        # ``step % gradient_accumulation_steps == 0`` check below is only
        # ever true when gradient_accumulation_steps == 1 — with a larger
        # accumulation setting the optimizer never steps.  Confirm intent.
        iter_lst = [self.get_iter(self.features_lst, self.args)]
        num_batches = sum([len(iterator[0]) for iterator in iter_lst])
        for epoch in range(self.args.start_epoch,
                           self.args.start_epoch + self.args.epochs):
            self.model.train()
            start = time.time()
            batch_step = 1
            for data_loader, sampler in iter_lst:
                if self.args.distributed:
                    # Reshuffle the DistributedSampler's partition each epoch.
                    sampler.set_epoch(epoch)

                for i, batch in enumerate(data_loader, start=1):
                    input_ids, input_mask, seg_ids, start_positions, end_positions, _ = batch

                    # Trim the batch to its longest real sequence: sign()
                    # marks non-pad ids, so max(sum) is the longest length.
                    seq_len = torch.sum(torch.sign(input_ids), 1)
                    max_len = torch.max(seq_len)

                    input_ids = input_ids[:, :max_len].clone()
                    input_mask = input_mask[:, :max_len].clone()
                    seg_ids = seg_ids[:, :max_len].clone()
                    start_positions = start_positions.clone()
                    end_positions = end_positions.clone()

                    if self.args.use_cuda:
                        input_ids = input_ids.cuda(self.args.gpu,
                                                   non_blocking=True)
                        input_mask = input_mask.cuda(self.args.gpu,
                                                     non_blocking=True)
                        seg_ids = seg_ids.cuda(self.args.gpu,
                                               non_blocking=True)
                        start_positions = start_positions.cuda(
                            self.args.gpu, non_blocking=True)
                        end_positions = end_positions.cuda(self.args.gpu,
                                                           non_blocking=True)

                    # Model returns the (possibly per-replica) QA loss.
                    loss = self.model(input_ids, seg_ids, input_mask,
                                      start_positions, end_positions)
                    loss = loss.mean()
                    loss = loss / self.args.gradient_accumulation_steps

                    # Elastic Weight Consolidation penalty.
                    ewc_loss = self.ewc_loss(cuda=True)
                    loss = loss + ewc_loss

                    loss.backward()

                    # Undo the accumulation scaling for reporting purposes.
                    avg_loss = self.cal_running_avg_loss(
                        loss.item() * self.args.gradient_accumulation_steps,
                        avg_loss)
                    if step % self.args.gradient_accumulation_steps == 0:
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                    # Mid-epoch evaluation every 2000 batches (skipped in
                    # the very first epoch).
                    if epoch != 0 and i % 2000 == 0:
                        result_dict = self.evaluate_model(i)
                        for dev_file, f1 in result_dict.items():
                            print("GPU/CPU {} evaluated {}: {:.2f}".format(
                                self.args.gpu, dev_file, f1),
                                  end="\n")

                    global_step += 1
                    batch_step += 1
                    msg = "{}/{} {} - ETA : {} - loss: {:.4f}" \
                        .format(batch_step, num_batches, progress_bar(batch_step, num_batches),
                                eta(start, batch_step, num_batches),
                                avg_loss)
                    print(msg, end="\r")

            print("[GPU Num: {}, epoch: {}, Final loss: {:.4f}]".format(
                self.args.gpu, epoch, avg_loss))

            # Only rank 0 writes checkpoints to avoid clobbering.
            if self.args.rank == 0:
                self.save_model(epoch, avg_loss)

            if self.args.do_valid:
                result_dict = self.evaluate_model(epoch)
                for dev_file, f1 in result_dict.items():
                    print("GPU/CPU {} evaluated {}: {:.2f}".format(
                        self.args.gpu, dev_file, f1),
                          end="\n")

        if consolidate:
            # Estimate the diagonal Fisher information of the parameters
            # and consolidate it into the network for the EWC penalty.
            print(
                '=> Estimating diagonals of the fisher information matrix...',
                flush=True,
                end='',
            )
            # ATTENTION: the data loader used here should cover the entire
            # training set.
            self.consolidate(
                self.estimate_fisher(
                    self.get_data_loader(self.features_lst, self.args),
                    fisher_estimation_sample_size))
            print('EWC Loaded!')

    def evaluate_model(self, epoch):
        """Run QA evaluation on every dev file assigned to this worker.

        Appends one ``<file> : <f1>`` line per dev file to
        ``dev_eval_<epoch>.txt`` in ``args.result_dir`` and collects the
        per-file F1 scores.

        Args:
            epoch: epoch number (or step index) used to name result files.

        Returns:
            dict mapping dev file name -> F1 score.
        """
        result_file = os.path.join(self.args.result_dir,
                                   "dev_eval_{}.txt".format(epoch))
        result_dict = dict()
        # FIX: use a context manager so the result file is closed even if
        # eval_qa raises (the original leaked the open handle in that case).
        with open(result_file, "a") as fw:
            for dev_file in self.dev_files:
                file_name = dev_file.split(".")[0]
                prediction_file = os.path.join(
                    self.args.result_dir,
                    "epoch_{}_{}.json".format(epoch, file_name))
                file_path = os.path.join(self.args.dev_folder, dev_file)
                # Renamed from ``metrics`` so the local does not shadow the
                # module-level ``metrics`` import used elsewhere in the file.
                eval_metrics = eval_qa(self.model,
                                       file_path,
                                       prediction_file,
                                       args=self.args,
                                       tokenizer=self.tokenizer,
                                       batch_size=self.args.batch_size)
                f1 = eval_metrics["f1"]
                fw.write("{} : {}\n".format(file_name, f1))
                result_dict[dev_file] = f1
        return result_dict

    @staticmethod
    def cal_running_avg_loss(loss, running_avg_loss, decay=0.99):
        if running_avg_loss == 0:
            return loss
        else:
            running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
            return running_avg_loss

    @staticmethod
    def set_random_seed(random_seed):
        if random_seed is not None:
            print("Set random seed as {}".format(random_seed))
            os.environ['PYTHONHASHSEED'] = str(random_seed)
            random.seed(random_seed)
            np.random.seed(random_seed)
            torch.manual_seed(random_seed)
            torch.cuda.manual_seed_all(random_seed)
            torch.set_num_threads(1)
            cudnn.benchmark = False
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')
# ----- Example #9 (示例#9) — paste-artifact separator between code samples -----
# 0
    ########## summary ##########
    print('[%d/%d] - time: %.2f, loss_d: %.3f, loss_g: %.3f' %
          ((epoch + 1), EPOCH, per_epoch_time,
           torch.mean(torch.FloatTensor(D_losses)),
           torch.mean(torch.FloatTensor(G_losses))))

    ########## test ##########
    G.eval()  # stop train and start test
    z = torch.randn((5*5, 100)).view(-1, 100, 1, 1)
    if GPU_MODE:
        z = Variable(z.cuda(), volatile=True)
    else:
        z = Variable(z, volatile=True)
    random_image = G(z)
    fixed_image = G(fixed_z)
    G.train()  # stop test and start train

    p = DIR + '/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'
    fixed_p = DIR + '/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'
    utils.save_result(random_image, (epoch+1), save=True, path=p)
    utils.save_result(fixed_image, (epoch+1), save=True, path=fixed_p)
    train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
    train_hist['per_epoch_times'].append(per_epoch_time)

end_time = time()
# BUG(review): ``end_time - end_time`` is always 0.0.  The intent is total
# wall-clock training time, i.e. subtracting the timestamp captured before
# the training loop (not visible in this chunk) — fix at that definition.
total_time = end_time - end_time
print("Avg per epoch time: %.2f, total %d epochs time: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_times'])), EPOCH, total_time))
print("Training finish!!!...")

# save parameters
# ----- Example #10 (示例#10) — paste-artifact separator between code samples -----
# 0
    def train(self):
        """Jointly train the photo-branch and sketch-branch networks.

        Builds two backbones chosen by ``self.net`` (one per modality),
        optionally restores fine-tune weights, then optimizes per-branch
        cross-entropy category losses plus a cross-modal triplet loss.
        Every ``self.test_f`` epochs instance recall is evaluated and both
        branches are checkpointed.
        """
        # NOTE(review): if ``self.net`` matches none of the three choices,
        # ``photo_net``/``sketch_net`` are never bound and the code below
        # raises NameError — consider failing fast with an explicit error.
        if self.net == 'vgg16':
            photo_net = DataParallel(self._get_vgg16()).cuda()
            sketch_net = DataParallel(self._get_vgg16()).cuda()
        elif self.net == 'resnet34':
            photo_net = DataParallel(self._get_resnet34()).cuda()
            sketch_net = DataParallel(self._get_resnet34()).cuda()
        elif self.net == 'resnet50':
            photo_net = DataParallel(self._get_resnet50()).cuda()
            sketch_net = DataParallel(self._get_resnet50()).cuda()

        if self.fine_tune:
            # Sketch weights live next to the photo weights, with 'photo'
            # swapped for 'sketch' in the path.
            photo_net_root = self.model_root
            sketch_net_root = self.model_root.replace('photo', 'sketch')

            photo_net.load_state_dict(
                t.load(photo_net_root, map_location=t.device('cpu')))
            sketch_net.load_state_dict(
                t.load(sketch_net_root, map_location=t.device('cpu')))

        print('net')
        print(photo_net)

        # triplet_loss = nn.TripletMarginLoss(margin=self.margin, p=self.p).cuda()
        # One CE loss per modality branch.
        photo_cat_loss = nn.CrossEntropyLoss().cuda()
        sketch_cat_loss = nn.CrossEntropyLoss().cuda()

        my_triplet_loss = TripletLoss().cuda()

        # One Adam optimizer per branch.
        photo_optimizer = t.optim.Adam(photo_net.parameters(), lr=self.lr)
        sketch_optimizer = t.optim.Adam(sketch_net.parameters(), lr=self.lr)

        if self.vis:
            vis = Visualizer(self.env)

        # Running-average meters reset after every batch (so each plot
        # point reflects a single batch).
        triplet_loss_meter = AverageValueMeter()
        sketch_cat_loss_meter = AverageValueMeter()
        photo_cat_loss_meter = AverageValueMeter()

        data_loader = TripleDataLoader(self.dataloader_opt)
        dataset = data_loader.load_data()

        for epoch in range(self.epochs):

            print('---------------{0}---------------'.format(epoch))

            # Periodic evaluation + checkpointing.
            if self.test and epoch % self.test_f == 0:

                tester_config = Config()
                tester_config.test_bs = 128
                tester_config.photo_net = photo_net
                tester_config.sketch_net = sketch_net

                tester_config.photo_test = self.photo_test
                tester_config.sketch_test = self.sketch_test

                tester = Tester(tester_config)
                test_result = tester.test_instance_recall()

                # NOTE(review): ``vis`` is only bound when ``self.vis`` is
                # True, but this plot call runs whenever the test branch
                # does — with self.vis False this raises NameError.
                result_key = list(test_result.keys())
                vis.plot('recall',
                         np.array([
                             test_result[result_key[0]],
                             test_result[result_key[1]]
                         ]),
                         legend=[result_key[0], result_key[1]])
                if self.save_model:
                    t.save(
                        photo_net.state_dict(), self.save_dir + '/photo' +
                        '/photo_' + self.net + '_%s.pth' % epoch)
                    t.save(
                        sketch_net.state_dict(), self.save_dir + '/sketch' +
                        '/sketch_' + self.net + '_%s.pth' % epoch)

            photo_net.train()
            sketch_net.train()

            for ii, data in enumerate(dataset):

                photo_optimizer.zero_grad()
                sketch_optimizer.zero_grad()

                photo = data['P'].cuda()
                sketch = data['S'].cuda()
                label = data['L'].cuda()

                # Each branch returns (category logits, embedding).
                p_cat, p_feature = photo_net(photo)
                s_cat, s_feature = sketch_net(sketch)

                # category loss
                p_cat_loss = photo_cat_loss(p_cat, label)
                s_cat_loss = sketch_cat_loss(s_cat, label)

                photo_cat_loss_meter.add(p_cat_loss.item())
                sketch_cat_loss_meter.add(s_cat_loss.item())

                # triplet loss
                loss = p_cat_loss + s_cat_loss

                # tri_record = 0.
                '''
                for i in range(self.batch_size):
                    # negative
                    negative_feature = t.cat([p_feature[0:i, :], p_feature[i + 1:, :]], dim=0)
                    # print('negative_feature.size :', negative_feature.size())
                    # photo_feature
                    anchor_feature = s_feature[i, :]
                    anchor_feature = anchor_feature.expand_as(negative_feature)
                    # print('anchor_feature.size :', anchor_feature.size())

                    # positive
                    positive_feature = p_feature[i, :]
                    positive_feature = positive_feature.expand_as(negative_feature)
                    # print('positive_feature.size :', positive_feature.size())

                    tri_loss = triplet_loss(anchor_feature, positive_feature, negative_feature)

                    tri_record = tri_record + tri_loss

                    # print('tri_loss :', tri_loss)
                    loss = loss + tri_loss
                '''
                # print('tri_record : ', tri_record)

                # Cross-modal triplet loss, normalized by the number of
                # negatives per anchor (batch_size - 1).
                my_tri_loss = my_triplet_loss(
                    s_feature, p_feature) / (self.batch_size - 1)
                triplet_loss_meter.add(my_tri_loss.item())
                # print('my_tri_loss : ', my_tri_loss)

                # print(tri_record - my_tri_loss)
                loss = loss + my_tri_loss
                # print('loss :', loss)
                # loss = loss / opt.batch_size

                loss.backward()

                photo_optimizer.step()
                sketch_optimizer.step()

                if self.vis:
                    vis.plot('triplet_loss',
                             np.array([
                                 triplet_loss_meter.value()[0],
                                 photo_cat_loss_meter.value()[0],
                                 sketch_cat_loss_meter.value()[0]
                             ]),
                             legend=[
                                 'triplet_loss', 'photo_cat_loss',
                                 'sketch_cat_loss'
                             ])

                # Reset so next batch's plot point is per-batch, not global.
                triplet_loss_meter.reset()
                photo_cat_loss_meter.reset()
                sketch_cat_loss_meter.reset()
# ----- Example #11 (示例#11) — paste-artifact separator between code samples -----
# 0
def train_model(train_dataset, train_num_each, val_dataset, val_num_each):
    num_train = len(train_dataset)
    num_val = len(val_dataset)

    train_useful_start_idx = get_useful_start_idx(sequence_length,
                                                  train_num_each)

    val_useful_start_idx = get_useful_start_idx(sequence_length, val_num_each)

    num_train_we_use = len(train_useful_start_idx) // num_gpu * num_gpu
    num_val_we_use = len(val_useful_start_idx) // num_gpu * num_gpu
    # num_train_we_use = 4
    # num_val_we_use = 800

    train_we_use_start_idx = train_useful_start_idx[0:num_train_we_use]
    val_we_use_start_idx = val_useful_start_idx[0:num_val_we_use]

    train_idx = []
    for i in range(num_train_we_use):
        for j in range(sequence_length):
            train_idx.append(train_we_use_start_idx[i] + j)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_we_use_start_idx[i] + j)

    num_train_all = len(train_idx)
    num_val_all = len(val_idx)

    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('last idx train start: {:6d}'.format(train_useful_start_idx[-1]))
    print('num of train dataset: {:6d}'.format(num_train))
    print('num of train we use : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(num_train_all))
    print('num valid start idx : {:6d}'.format(len(val_useful_start_idx)))
    print('last idx valid start: {:6d}'.format(val_useful_start_idx[-1]))
    print('num of valid dataset: {:6d}'.format(num_val))
    print('num of valid we use : {:6d}'.format(num_val_we_use))
    print('num of all valid use: {:6d}'.format(num_val_all))

    train_loader = DataLoader(train_dataset,
                              batch_size=train_batch_size,
                              sampler=train_idx,
                              num_workers=workers,
                              pin_memory=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=val_batch_size,
                            sampler=val_idx,
                            num_workers=workers,
                            pin_memory=False)
    model = multi_lstm_4loss()
    sig_f = nn.Sigmoid()

    if use_gpu:
        model = model.cuda()
        sig_f = sig_f.cuda()
    model = DataParallel(model)
    criterion_1 = nn.BCEWithLogitsLoss(size_average=False)
    criterion_2 = nn.CrossEntropyLoss(size_average=False)

    if multi_optim == 0:
        if optimizer_choice == 0:
            optimizer = optim.SGD(model.parameters(),
                                  lr=learning_rate,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_adjust_lr,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    elif multi_optim == 1:
        if optimizer_choice == 0:
            optimizer = optim.SGD([
                {
                    'params': model.module.share.parameters()
                },
                {
                    'params': model.module.lstm.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool2.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase2.parameters(),
                    'lr': learning_rate
                },
            ],
                                  lr=learning_rate / 10,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_adjust_lr,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam([
                {
                    'params': model.module.share.parameters()
                },
                {
                    'params': model.module.lstm.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_tool2.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase1.parameters(),
                    'lr': learning_rate
                },
                {
                    'params': model.module.fc_phase2.parameters(),
                    'lr': learning_rate
                },
            ],
                                   lr=learning_rate / 10)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy_1 = 0.0
    best_val_accuracy_2 = 0.0  # judge by accu2
    correspond_train_acc_1 = 0.0
    correspond_train_acc_2 = 0.0

    record_np = np.zeros([epochs, 8])

    for epoch in range(epochs):
        # np.random.seed(epoch)
        np.random.shuffle(train_we_use_start_idx)
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_we_use_start_idx[i] + j)

        train_loader = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  sampler=train_idx,
                                  num_workers=workers,
                                  pin_memory=False)

        model.train()
        train_loss_11 = 0.0
        train_loss_12 = 0.0
        train_loss_21 = 0.0
        train_loss_22 = 0.0
        train_corrects_11 = 0
        train_corrects_12 = 0
        train_corrects_21 = 0
        train_corrects_22 = 0

        train_start_time = time.time()
        for data in train_loader:
            inputs, labels_1, labels_2 = data
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels_1 = Variable(labels_1.cuda())
                labels_2 = Variable(labels_2.cuda())
            else:
                inputs = Variable(inputs)
                labels_1 = Variable(labels_1)
                labels_2 = Variable(labels_2)

            optimizer.zero_grad()

            outputs_11, outputs_12, outputs_21, outputs_22 = model.forward(
                inputs)

            _, preds_12 = torch.max(outputs_12.data, 1)
            _, preds_22 = torch.max(outputs_22.data, 1)

            sig_out_11 = sig_f(outputs_11.data)
            sig_out_21 = sig_f(outputs_21.data)

            preds_11 = torch.ByteTensor(sig_out_11.cpu() > 0.5)
            preds_11 = preds_11.long()
            train_corrects_11 += torch.sum(preds_11 == labels_1.data.cpu())
            preds_21 = torch.ByteTensor(sig_out_21.cpu() > 0.5)
            preds_21 = preds_21.long()
            train_corrects_21 += torch.sum(preds_21 == labels_1.data.cpu())

            labels_1 = Variable(labels_1.data.float())
            loss_11 = criterion_1(outputs_11, labels_1)
            loss_21 = criterion_1(outputs_21, labels_1)

            loss_12 = criterion_2(outputs_12, labels_2)
            loss_22 = criterion_2(outputs_22, labels_2)
            loss = loss_11 + loss_12 + loss_21 + loss_22
            loss.backward()
            optimizer.step()

            train_loss_11 += loss_11.data[0]
            train_loss_12 += loss_12.data[0]
            train_loss_21 += loss_21.data[0]
            train_loss_22 += loss_22.data[0]
            train_corrects_12 += torch.sum(preds_12 == labels_2.data)
            train_corrects_22 += torch.sum(preds_22 == labels_2.data)

        train_elapsed_time = time.time() - train_start_time
        train_accuracy_11 = train_corrects_11 / num_train_all / 7
        train_accuracy_21 = train_corrects_21 / num_train_all / 7
        train_accuracy_12 = train_corrects_12 / num_train_all
        train_accuracy_22 = train_corrects_22 / num_train_all
        train_average_loss_11 = train_loss_11 / num_train_all / 7
        train_average_loss_21 = train_loss_21 / num_train_all / 7
        train_average_loss_12 = train_loss_12 / num_train_all
        train_average_loss_22 = train_loss_22 / num_train_all

        # begin eval

        model.eval()
        val_loss_11 = 0.0
        val_loss_12 = 0.0
        val_loss_21 = 0.0
        val_loss_22 = 0.0
        val_corrects_11 = 0
        val_corrects_12 = 0
        val_corrects_21 = 0
        val_corrects_22 = 0

        val_start_time = time.time()
        for data in val_loader:
            inputs, labels_1, labels_2 = data
            labels_2 = labels_2[(sequence_length - 1)::sequence_length]
            if use_gpu:
                inputs = Variable(inputs.cuda(), volatile=True)
                labels_1 = Variable(labels_1.cuda(), volatile=True)
                labels_2 = Variable(labels_2.cuda(), volatile=True)
            else:
                inputs = Variable(inputs, volatile=True)
                labels_1 = Variable(labels_1, volatile=True)
                labels_2 = Variable(labels_2, volatile=True)

            # if crop_type == 0 or crop_type == 1:
            #     outputs_1, outputs_2 = model.forward(inputs)
            # elif crop_type == 5:
            #     inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
            #     inputs = inputs.view(-1, 3, 224, 224)
            #     outputs_1, outputs_2 = model.forward(inputs)
            #     outputs_1 = outputs_1.view(5, -1, 7)
            #     outputs_1 = torch.mean(outputs_1, 0)
            #     outputs_2 = outputs_2.view(5, -1, 7)
            #     outputs_2 = torch.mean(outputs_2, 0)
            # elif crop_type == 10:
            #     inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
            #     inputs = inputs.view(-1, 3, 224, 224)
            #     outputs_1, outputs_2 = model.forward(inputs)
            #     outputs_1 = outputs_1.view(10, -1, 7)
            #     outputs_1 = torch.mean(outputs_1, 0)
            #     outputs_2 = outputs_2.view(10, -1, 7)
            #     outputs_2 = torch.mean(outputs_2, 0)
            outputs_11, outputs_12, outputs_21, outputs_22 = model.forward(
                inputs)
            outputs_12 = outputs_12[sequence_length - 1::sequence_length]
            outputs_22 = outputs_22[sequence_length - 1::sequence_length]

            _, preds_12 = torch.max(outputs_12.data, 1)
            _, preds_22 = torch.max(outputs_22.data, 1)

            sig_out_11 = sig_f(outputs_11.data)
            sig_out_21 = sig_f(outputs_21.data)

            preds_11 = torch.ByteTensor(sig_out_11.cpu() > 0.5)
            preds_11 = preds_11.long()
            train_corrects_11 += torch.sum(preds_11 == labels_1.data.cpu())
            preds_21 = torch.ByteTensor(sig_out_21.cpu() > 0.5)
            preds_21 = preds_21.long()
            train_corrects_21 += torch.sum(preds_21 == labels_1.data.cpu())

            labels_1 = Variable(labels_1.data.float())
            loss_11 = criterion_1(outputs_11, labels_1)
            loss_21 = criterion_1(outputs_21, labels_1)

            loss_12 = criterion_2(outputs_12, labels_2)
            loss_22 = criterion_2(outputs_22, labels_2)

            val_loss_11 += loss_11.data[0]
            val_loss_12 += loss_12.data[0]
            val_loss_21 += loss_21.data[0]
            val_loss_22 += loss_22.data[0]
            val_corrects_12 += torch.sum(preds_12 == labels_2.data)
            val_corrects_22 += torch.sum(preds_22 == labels_2.data)

        val_elapsed_time = time.time() - val_start_time
        val_accuracy_11 = val_corrects_11 / num_val_all / 7
        val_accuracy_21 = val_corrects_21 / num_val_all / 7
        val_accuracy_12 = val_corrects_12 / num_val_we_use
        val_accuracy_22 = val_corrects_22 / num_val_we_use
        val_average_loss_11 = val_loss_11 / num_val_all / 7
        val_average_loss_21 = val_loss_21 / num_val_all / 7
        val_average_loss_12 = val_loss_12 / num_val_we_use
        val_average_loss_22 = val_loss_22 / num_val_we_use

        print('epoch: {:4d}'
              ' train time: {:2.0f}m{:2.0f}s'
              ' train accu_11: {:.4f}'
              ' train accu_21: {:.4f}'
              ' valid time: {:2.0f}m{:2.0f}s'
              ' valid accu_11: {:.4f}'
              ' valid accu_21: {:.4f}'.format(
                  epoch, train_elapsed_time // 60, train_elapsed_time % 60,
                  train_accuracy_11, train_accuracy_21, val_elapsed_time // 60,
                  val_elapsed_time % 60, val_accuracy_11, val_accuracy_21))
        print('epoch: {:4d}'
              ' train time: {:2.0f}m{:2.0f}s'
              ' train accu_12: {:.4f}'
              ' train accu_22: {:.4f}'
              ' valid time: {:2.0f}m{:2.0f}s'
              ' valid accu_12: {:.4f}'
              ' valid accu_22: {:.4f}'.format(
                  epoch, train_elapsed_time // 60, train_elapsed_time % 60,
                  train_accuracy_12, train_accuracy_22, val_elapsed_time // 60,
                  val_elapsed_time % 60, val_accuracy_12, val_accuracy_22))

        if optimizer_choice == 0:
            if sgd_adjust_lr == 0:
                exp_lr_scheduler.step()
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler.step(val_average_loss_11 +
                                      val_average_loss_12 +
                                      val_average_loss_21 +
                                      val_average_loss_22)
Example #12
0
File: GAN.py  Project: kristofe/BMSG-GAN
class MSG_GAN:
    """ Unconditional TeacherGAN

        args:
            depth: depth of the GAN (will be used for each generator and discriminator)
            latent_size: latent size of the manifold used by the GAN
            use_eql: whether to use the equalized learning rate
            use_ema: whether to use exponential moving averages.
            ema_decay: value of ema decay. Used only if use_ema is True
            device: device to run the GAN on (GPU / CPU)
    """

    def __init__(self, depth=7, latent_size=512,
                 use_eql=True, use_ema=True, ema_decay=0.999,
                 device=th.device("cpu")):
        """ constructor for the class """
        from torch.nn import DataParallel

        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)

        # Parallelize them if required:
        if device == th.device("cuda"):
            self.gen = DataParallel(self.gen)
            self.dis = Discriminator(depth, latent_size,
                                     use_eql=use_eql, gpu_parallelize=True).to(device)
        else:
            # BUGFIX: honour the use_eql argument on the CPU path too
            # (it was previously hard-coded to True, ignoring the flag).
            self.dis = Discriminator(depth, latent_size, use_eql=use_eql).to(device)

        # state of the object
        self.use_ema = use_ema
        self.ema_decay = ema_decay
        self.use_eql = use_eql
        self.latent_size = latent_size
        self.depth = depth
        self.device = device

        if self.use_ema:
            from MSG_GAN.CustomLayers import update_average

            # create a shadow copy of the generator
            self.gen_shadow = copy.deepcopy(self.gen)

            # updater function:
            self.ema_updater = update_average

            # initialize the gen_shadow weights equal to the
            # weights of gen (beta=0 means "copy, no averaging")
            self.ema_updater(self.gen_shadow, self.gen, beta=0)

        # by default the generator and discriminator are in eval mode
        self.gen.eval()
        self.dis.eval()
        if self.use_ema:
            self.gen_shadow.eval()

    def generate_samples(self, num_samples):
        """
        generate samples using this gan
        :param num_samples: number of samples to be generated
        :return: generated samples tensor: list[ Tensor(B x H x W x C)]
        """
        noise = th.randn(num_samples, self.latent_size).to(self.device)
        generated_images = self.gen(noise)

        # reshape the generated images to channels-last and map the
        # tanh range [-1, 1] into [0, 1]
        generated_images = list(map(lambda x: (x.detach().permute(0, 2, 3, 1) / 2) + 0.5,
                                    generated_images))

        return generated_images

    def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on discriminator using the batch of data
        :param dis_optim: discriminator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples; detach so generator weights
        # receive no gradient from the discriminator update
        fake_samples = self.gen(noise)
        fake_samples = list(map(lambda x: x.detach(), fake_samples))

        loss = loss_fn.dis_loss(real_batch, fake_samples)

        # optimize discriminator
        dis_optim.zero_grad()
        loss.backward()
        dis_optim.step()

        return loss.item()

    def optimize_generator(self, gen_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on generator using the batch of data
        :param gen_optim: generator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples (NOT detached: gradients must flow)
        fake_samples = self.gen(noise)

        loss = loss_fn.gen_loss(real_batch, fake_samples)

        # optimize generator
        gen_optim.zero_grad()
        loss.backward()
        gen_optim.step()

        # if self.use_ema is true, apply the moving average here:
        if self.use_ema:
            self.ema_updater(self.gen_shadow, self.gen, self.ema_decay)

        return loss.item()

    def create_grid(self, samples, img_files):
        """
        utility function to create a grid of GAN samples
        :param samples: generated samples for storing list[Tensors]
        :param img_files: list of names of files to write
        :return: None (saves multiple files)
        """
        from torchvision.utils import save_image
        from torch.nn.functional import interpolate
        from numpy import sqrt, power

        # dynamically adjust the colour of the images
        samples = [Generator.adjust_dynamic_range(sample) for sample in samples]

        # resize the samples to have same resolution
        # (coarsest scale is upsampled the most)
        for i in range(len(samples)):
            samples[i] = interpolate(samples[i],
                                     scale_factor=power(2,
                                                        self.depth - 1 - i))
        # save the images:
        for sample, img_file in zip(samples, img_files):
            save_image(sample, img_file, nrow=int(sqrt(sample.shape[0])),
                       normalize=True, scale_each=True, padding=0)

    def train(self, data, gen_optim, dis_optim, loss_fn, normalize_latents=True,
              start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
              data_percentage=100, num_samples=36,
              log_dir=None, sample_dir="./samples",
              save_dir="./models"):
        """
        Method for training the network
        :param data: pytorch dataloader which iterates over images
        :param gen_optim: Optimizer for generator.
                          please wrap this inside a Scheduler if you want to
        :param dis_optim: Optimizer for discriminator.
                          please wrap this inside a Scheduler if you want to
        :param loss_fn: Object of GANLoss
        :param normalize_latents: whether to normalize the latent vectors during training
        :param start: starting epoch number
        :param num_epochs: total number of epochs to run for (ending epoch number)
                           note this is absolute and not relative to start
        :param feedback_factor: number of logs generated and samples generated
                                during training per epoch
        :param checkpoint_factor: save model after these many epochs
        :param data_percentage: amount of data to be used
        :param num_samples: number of samples to be drawn for feedback grid
        :param log_dir: path to directory for saving the loss.log file
        :param sample_dir: path to directory for saving generated samples' grids
        :param save_dir: path to directory for saving the trained models
        :return: None (writes multiple files to disk)
        """

        from torch.nn.functional import avg_pool2d

        # turn the generator and discriminator into train mode
        self.gen.train()
        self.dis.train()

        assert isinstance(gen_optim, th.optim.Optimizer), \
            "gen_optim is not an Optimizer"
        assert isinstance(dis_optim, th.optim.Optimizer), \
            "dis_optim is not an Optimizer"

        print("Starting the training process ... ")

        # create fixed_input for debugging
        fixed_input = th.randn(num_samples, self.latent_size).to(self.device)
        if normalize_latents:
            fixed_input = (fixed_input
                           / fixed_input.norm(dim=-1, keepdim=True)
                           * (self.latent_size ** 0.5))

        # create a global time counter
        global_time = time.time()
        global_step = 0

        for epoch in range(start, num_epochs + 1):
            start_time = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            # DataLoader has __len__; no need to build an iterator first
            total_batches = len(data)

            limit = int((data_percentage / 100) * total_batches)

            for (i, batch) in enumerate(data, 1):

                # extract current batch of data for training
                images = batch.to(self.device)
                extracted_batch_size = images.shape[0]

                # create a list of downsampled images from the real images,
                # ordered coarsest-first to match the generator outputs
                images = [images] + [avg_pool2d(images, int(np.power(2, p)))
                                     for p in range(1, self.depth)]
                images = list(reversed(images))

                # sample some random latent points
                gan_input = th.randn(
                    extracted_batch_size, self.latent_size).to(self.device)

                # normalize them if asked
                if normalize_latents:
                    gan_input = (gan_input
                                 / gan_input.norm(dim=-1, keepdim=True)
                                 * (self.latent_size ** 0.5))

                # optimize the discriminator:
                dis_loss = self.optimize_discriminator(dis_optim, gan_input,
                                                       images, loss_fn)

                # optimize the generator:
                gen_loss = self.optimize_generator(gen_optim, gan_input,
                                                   images, loss_fn)

                # provide a loss feedback
                if i % (int(limit / feedback_factor) + 1) == 0 or i == 1:     # Avoid div by 0 error on small training sets
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("Elapsed [%s] batch: %d  d_loss: %f  g_loss: %f"
                          % (elapsed, i, dis_loss, gen_loss))

                    # also write the losses to the log file:
                    if log_dir is not None:
                        log_file = os.path.join(log_dir, "loss.log")
                        os.makedirs(os.path.dirname(log_file), exist_ok=True)
                        with open(log_file, "a") as log:
                            log.write(str(global_step) + "\t" + str(dis_loss) +
                                      "\t" + str(gen_loss) + "\n")

                    # create a grid of samples and save it
                    reses = [str(int(np.power(2, dep))) + "_x_"
                             + str(int(np.power(2, dep)))
                             for dep in range(2, self.depth + 2)]
                    gen_img_files = [os.path.join(sample_dir, res, "gen_" +
                                                  str(epoch) + "_" +
                                                  str(i) + ".png")
                                     for res in reses]

                    # Make sure all the required directories exist
                    # otherwise make them
                    os.makedirs(sample_dir, exist_ok=True)
                    for gen_img_file in gen_img_files:
                        os.makedirs(os.path.dirname(gen_img_file), exist_ok=True)

                    dis_optim.zero_grad()
                    gen_optim.zero_grad()
                    with th.no_grad():
                        self.create_grid(
                            self.gen(fixed_input) if not self.use_ema
                            else self.gen_shadow(fixed_input),
                            gen_img_files)

                # increment the global_step:
                global_step += 1

                if i > limit:
                    break

            # calculate the time required for the epoch
            stop_time = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop_time - start_time))

            # guard against checkpoint_factor == 0 (would raise ZeroDivisionError)
            if (checkpoint_factor and epoch % checkpoint_factor == 0) \
                    or epoch == 1 or epoch == num_epochs:
                os.makedirs(save_dir, exist_ok=True)
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth")
                gen_optim_save_file = os.path.join(save_dir,
                                                   "GAN_GEN_OPTIM_" + str(epoch) + ".pth")
                dis_optim_save_file = os.path.join(save_dir,
                                                   "GAN_DIS_OPTIM_" + str(epoch) + ".pth")

                th.save(self.gen.state_dict(), gen_save_file)
                th.save(self.dis.state_dict(), dis_save_file)
                th.save(gen_optim.state_dict(), gen_optim_save_file)
                th.save(dis_optim.state_dict(), dis_optim_save_file)

                if self.use_ema:
                    gen_shadow_save_file = os.path.join(save_dir, "GAN_GEN_SHADOW_"
                                                        + str(epoch) + ".pth")
                    th.save(self.gen_shadow.state_dict(), gen_shadow_save_file)

        print("Training completed ...")

        # return the generator and discriminator back to eval mode
        self.gen.eval()
        self.dis.eval()
Example #13
0
    def train_model(self, epochs=30, n=None):
        """Train ``self.model`` on the 'train'/'val' loaders built by
        ``self.transform()`` and keep the best validation-accuracy model.

        The best model (by 'val' accuracy) is deep-copied and stored back
        into ``self.model`` (moved to CPU) when training finishes.

        :param epochs: number of epochs to run
        :param n: architecture tag; ``'Inception'`` enables aux-logit loss
                  handling and pads short batches to a fixed batch size
        """
        dsets, dset_loaders, dset_sizes = self.transform()
        print(dset_sizes)

        model = self.model
        epoch_init = 0

        if self.resume:
            try:
                self.load_model(filename=self.model_path_continue)
            except Exception:
                # best-effort resume: fall back to fresh weights
                print('invalid directory, starting from scratch')

        best_model = model
        criterion = self.criterion

        # single-GPU path: self.gpu is an int index
        if (torch.cuda.is_available() and self.use_gpu
                and self.gpu in range(0, torch.cuda.device_count())):
            model = model.cuda()
        # multi-GPU path: self.gpu is a list of device ids
        elif torch.cuda.is_available() and self.use_gpu and len(self.gpu) > 1:
            model = DataParallel(model, device_ids=self.gpu).cuda()

        best_acc = 0.0
        best_epoch = 0

        # BUGFIX: the original `if(~self.fe):` used bitwise NOT on a bool,
        # which is always truthy (~True == -2, ~False == -1), so the flag
        # never disabled training. Logical `not` restores the intent.
        if not self.fe:
            for epoch in range(epoch_init, epochs):
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train(True)
                        self.lr_scheduler(epoch)
                    else:
                        model.train(False)

                    c_mat = np.zeros((self.num_output, self.num_output))
                    running_loss = 0.0
                    running_corrects = 0.0

                    for data in dset_loaders[phase]:
                        inputs, labels = data
                        if torch.cuda.is_available() and self.use_gpu:
                            # BUGFIX: `.cuda(async=True)` is a SyntaxError on
                            # Python >= 3.7 (`async` is a keyword); the kwarg
                            # was renamed to `non_blocking` in PyTorch 0.4.
                            inputs = Variable(inputs.cuda(non_blocking=True))
                            labels = Variable(labels.cuda(non_blocking=True))
                        else:
                            inputs, labels = Variable(inputs), Variable(labels)

                        self.model_optimizer.zero_grad()

                        # Inception requires a fixed batch size: pad short
                        # final batches with zero images / label 0.
                        if inputs.size(0) < self.b_size and n == 'Inception':
                            if torch.cuda.is_available() and self.use_gpu:
                                temp = Variable(torch.zeros(
                                    (self.b_size, 3, 300, 300)).cuda(non_blocking=True))
                                temp2 = Variable(torch.LongTensor(
                                    self.b_size).cuda(non_blocking=True))
                            else:
                                temp = Variable(torch.zeros((self.b_size, 3, 300, 300)))
                                temp2 = Variable(torch.LongTensor(self.b_size))

                            temp[0:inputs.size(0)] = inputs
                            inputs = temp
                            del temp
                            temp2[0:labels.size(0)] = labels
                            temp2[labels.size(0):] = 0
                            labels = temp2
                            del temp2

                        outputs = model(inputs)

                        if n == 'Inception':
                            if phase == 'val':
                                # eval mode: no aux logits returned
                                _, preds = torch.max(outputs.data, 1)
                                loss = criterion(outputs, labels)
                            else:
                                # train mode returns (main, aux) logits
                                _, preds = torch.max(outputs[0].data, 1)
                                loss = criterion(outputs[0], labels)
                        else:
                            _, preds = torch.max(outputs.data, 1)
                            loss = criterion(outputs, labels)

                        if phase == 'train':
                            loss.backward()
                            self.model_optimizer.step()

                        # `.item()` replaces `loss.data[0]`, which raises on
                        # 0-dim tensors in PyTorch >= 0.5.
                        running_loss += loss.item()
                        running_corrects += torch.sum(preds == labels.data).item()

                        # accumulate confusion matrix: rows = true, cols = pred
                        for i in range(0, labels.data.cpu().numpy().shape[0]):
                            c_mat[labels.data.cpu().numpy()[i],
                                  preds.cpu().numpy()[i]] += 1

                        self.epochs = epoch
                        del inputs
                        del labels
                        del outputs

                    epoch_loss = running_loss / dset_sizes[phase]
                    epoch_acc = running_corrects / dset_sizes[phase]
                    # BUGFIX: the original printed the phase twice
                    # (`phase + '{} ...'.format(phase, ...)`).
                    print('{} Loss: {:.10f} \nAcc: {:.4f}'.format(
                        phase, epoch_loss, epoch_acc))

                    if self.verbose:
                        print(c_mat)

                    if phase == 'val' and epoch_acc > best_acc:
                        best_acc = epoch_acc
                        best_model = copy.deepcopy(model)
                        best_epoch = epoch

        print(best_acc)
        print(best_epoch)

        self.model = best_model.cpu()
Example #14
0
def train(args):
    """Train an EfficientNet-b0 binary classifier (anti-spoofing) with SGD.

    :param args: namespace with at least ``gpus`` (comma-separated device
        ids), ``save_dir``, ``total_epoch``, ``save_freq``, ``test_freq``
        and ``has_test``.
    Side effects: writes checkpoints under ``save_dir`` and plots curves
    to a Visdom environment.
    """
    # gpu init
    multi_gpus = False
    if len(args.gpus.split(',')) > 1:
        multi_gpus = True
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # log init (renamed local: `logging` shadowed the stdlib module)
    save_dir = os.path.join(args.save_dir, datetime.now().date().strftime('%Y%m%d'))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    logger = init_log(save_dir)
    _print = logger.info
    # define transform: resize, flip-augment, then map [0, 255] -> [-1, 1]
    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # range [0.0, 1.0] -> [-1.0,1.0]
    ])
    net = EfficientNet.from_name('efficientnet-b0', num_classes=2)

    # training dataset
    trainset = ANTI(train_root="/mnt/sda3/data/FASD", file_list="train.txt", transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=2,
                                              shuffle=True, num_workers=8, drop_last=False)

    # loss and optimizer
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer_ft = optim.SGD([
        {'params': net.parameters(), 'weight_decay': 5e-4},
    ], lr=0.001, momentum=0.9, nesterov=True)

    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[6, 10, 30], gamma=0.1)
    if multi_gpus:
        net = DataParallel(net).to(device)
    else:
        net = net.to(device)

    total_iters = 1
    vis = Visualizer(env="effiction")

    for epoch in range(1, args.total_epoch + 1):
        _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
        net.train()
        since = time.time()
        for data in trainloader:
            img, label = data[0].to(device), data[1].to(device)
            optimizer_ft.zero_grad()
            raw_logits = net(img)
            total_loss = criterion(raw_logits, label)
            total_loss.backward()
            optimizer_ft.step()
            # print train information
            if total_iters % 200 == 0:
                # current training accuracy.
                # BUGFIX: np.array() on CUDA tensors raises; compare tensors
                # directly and read the count with .item().
                _, predict = torch.max(raw_logits.data, 1)
                total = label.size(0)
                correct = (predict == label).sum().item()
                # BUGFIX: the window is 200 iterations, so divide by 200
                # (was /100, over-reporting per-iter time by 2x)
                time_cur = (time.time() - since) / 200
                since = time.time()
                vis.plot_curves({'softmax loss': total_loss.item()}, iters=total_iters, title='train loss',
                                xlabel='iters', ylabel='train loss')
                vis.plot_curves({'train accuracy': correct / total}, iters=total_iters, title='train accuracy', xlabel='iters',
                                ylabel='train accuracy')

                print("Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, epoch, total_loss.item(), correct/total, time_cur, exp_lr_scheduler.get_lr()[0]))

            # save model
            if total_iters % args.save_freq == 0:
                msg = 'Saving checkpoint: {}'.format(total_iters)
                _print(msg)
                if multi_gpus:
                    net_state_dict = net.module.state_dict()
                else:
                    net_state_dict = net.state_dict()

                if not os.path.exists(save_dir):
                    os.mkdir(save_dir)

                torch.save({
                    'iters': total_iters,
                    'net_state_dict': net_state_dict},
                    os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))

            # test accuracy
            if total_iters % args.test_freq == 0 and args.has_test:
                # test model on lfw
                net.eval()
                # NOTE(review): `lfw_accs` is not defined anywhere in this
                # function — this branch raises NameError; confirm where the
                # LFW evaluation results should come from.
                _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100))

                net.train()
            total_iters += 1
        # step the LR schedule once per epoch, AFTER the optimizer steps
        # (required ordering since PyTorch 1.1; was previously called first)
        exp_lr_scheduler.step()
    print('finishing training')
Example #15
0
# Script-level ArcFace-style training loop.
# NOTE(review): `net`, `ArcMargin`, `multi_gpus`, `start_epoch`, `TOTAL_EPOCH`,
# `exp_lr_scheduler`, `_print`, `trainloader` and `optimizer_ft` are defined
# earlier in the original script (outside this excerpt).
net = net.cuda()
ArcMargin = ArcMargin.cuda()
if multi_gpus:
    # wrap both the backbone and the margin head for multi-GPU training
    net = DataParallel(net)
    ArcMargin = DataParallel(ArcMargin)
criterion = torch.nn.CrossEntropyLoss()


best_acc = 0.0
best_epoch = 0
for epoch in range(start_epoch, TOTAL_EPOCH+1):
    # NOTE(review): scheduler stepped before the optimizer — PyTorch >= 1.1
    # expects scheduler.step() after the epoch's optimizer steps; confirm.
    exp_lr_scheduler.step()
    # train model
    _print('Train Epoch: {}/{} ...'.format(epoch, TOTAL_EPOCH))
    net.train()

    train_total_loss = 0.0
    total = 0
    since = time.time()
    for data in trainloader:
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        optimizer_ft.zero_grad()

        # backbone features -> margin head -> cross-entropy on logits
        raw_logits = net(img)

        output = ArcMargin(raw_logits, label)
        total_loss = criterion(output, label)
        total_loss.backward()
        optimizer_ft.step()
Example #16
0
def train_model(train_dataset, train_num_each, val_dataset, val_num_each):
    """Fine-tune a multi-task CNN-LSTM (tool head 1, phase head 2, auxiliary
    head 3) with a consistency (KL) loss between head 1 and head 3, keeping
    the weights of the best validation epoch.

    Args:
        train_dataset / val_dataset: frame-level datasets.
        train_num_each / val_num_each: per-video frame counts, used to pick
            clip start indices so that every sampled clip of
            ``sequence_length`` frames stays inside one video.

    Relies on module-level configuration (sequence_length, num_gpu, workers,
    train_batch_size, val_batch_size, optimizer_choice, multi_optim, alpha,
    epochs, learning_rate, ...) and on two pretrained checkpoints on disk.

    Side effects: saves good checkpoints as .pth files and the per-epoch
    metric table as a .npy file.
    """
    num_train = len(train_dataset)
    num_val = len(val_dataset)

    # Clip start indices that keep a whole clip inside a single video.
    train_useful_start_idx = get_useful_start_idx(sequence_length,
                                                  train_num_each)

    val_useful_start_idx = get_useful_start_idx(sequence_length, val_num_each)

    # Truncate so the number of clips divides evenly across the GPUs.
    num_train_we_use = len(train_useful_start_idx) // num_gpu * num_gpu
    num_val_we_use = len(val_useful_start_idx) // num_gpu * num_gpu
    # num_train_we_use = 800
    # num_val_we_use = 80

    train_we_use_start_idx = train_useful_start_idx[0:num_train_we_use]
    val_we_use_start_idx = val_useful_start_idx[0:num_val_we_use]

    # Expand each start index into the full run of frame indices for a clip.
    train_idx = []
    for i in range(num_train_we_use):
        for j in range(sequence_length):
            train_idx.append(train_we_use_start_idx[i] + j)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_we_use_start_idx[i] + j)

    num_train_all = len(train_idx)
    num_val_all = len(val_idx)

    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('last idx train start: {:6d}'.format(train_useful_start_idx[-1]))
    print('num of train dataset: {:6d}'.format(num_train))
    print('num of train we use : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(num_train_all))
    print('num valid start idx : {:6d}'.format(len(val_useful_start_idx)))
    print('last idx valid start: {:6d}'.format(val_useful_start_idx[-1]))
    print('num of valid dataset: {:6d}'.format(num_val))
    print('num of valid we use : {:6d}'.format(num_val_we_use))
    print('num of all valid use: {:6d}'.format(num_val_all))

    # The explicit index lists act as (non-shuffling) samplers: batches are
    # contiguous clip frames in order.
    train_loader = DataLoader(train_dataset,
                              batch_size=train_batch_size,
                              sampler=train_idx,
                              num_workers=workers,
                              pin_memory=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=val_batch_size,
                            sampler=val_idx,
                            num_workers=workers,
                            pin_memory=False)

    # Warm-start the shared trunk / LSTM / heads from the pretrained
    # two-task model.
    model_old = multi_lstm()
    model_old = DataParallel(model_old)
    model_old.load_state_dict(
        torch.load(
            "cnn_lstm_epoch_25_length_10_opt_1_mulopt_1_flip_0_crop_1_batch_400_train1_9997_train2_9982_val1_9744_val2_8876.pth"
        ))

    model = multi_lstm_p2t()
    model.share = model_old.module.share
    model.lstm = model_old.module.lstm
    model.fc = model_old.module.fc
    model.fc2 = model_old.module.fc2

    model = DataParallel(model)
    # The p2t head is frozen: it only provides targets for the KL term.
    for param in model.module.fc_p2t.parameters():
        param.requires_grad = False
    model.module.fc_p2t.load_state_dict(
        torch.load(
            "fc_epoch_25_length_4_opt_1_mulopt_1_flip_0_crop_1_batch_800_train1_9951_train2_9713_val1_9686_val2_7867_p2t.pth"
        ))

    if use_gpu:
        model = model.cuda()
        model.module.fc_p2t = model.module.fc_p2t.cuda()

    # NOTE(review): size_average=False is the deprecated spelling of
    # reduction='sum' — kept as-is for the legacy PyTorch this targets.
    criterion_1 = nn.BCEWithLogitsLoss(size_average=False)
    criterion_2 = nn.CrossEntropyLoss(size_average=False)
    criterion_3 = nn.KLDivLoss(size_average=False)
    sigmoid_cuda = nn.Sigmoid().cuda()

    # Optimizer setup: multi_optim selects uniform vs per-group learning
    # rates; optimizer_choice selects SGD (0) vs Adam (1).
    if multi_optim == 0:
        if optimizer_choice == 0:
            optimizer = optim.SGD([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
            }, {
                'params': model.module.fc.parameters()
            }, {
                'params': model.module.fc2.parameters()
            }],
                                  lr=learning_rate,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_step,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
            }, {
                'params': model.module.fc.parameters()
            }, {
                'params': model.module.fc2.parameters()
            }],
                                   lr=learning_rate)
    elif multi_optim == 1:
        if optimizer_choice == 0:
            # Shared trunk trains at lr/10; the newer layers at full lr.
            optimizer = optim.SGD([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
                'lr': learning_rate
            }, {
                'params': model.module.fc.parameters(),
                'lr': learning_rate
            }],
                                  lr=learning_rate / 10,
                                  momentum=momentum,
                                  dampening=dampening,
                                  weight_decay=weight_decay,
                                  nesterov=use_nesterov)
            if sgd_adjust_lr == 0:
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
                                                       step_size=sgd_step,
                                                       gamma=sgd_gamma)
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
                    optimizer, 'min')
        elif optimizer_choice == 1:
            optimizer = optim.Adam([{
                'params': model.module.share.parameters()
            }, {
                'params': model.module.lstm.parameters(),
                'lr': learning_rate
            }, {
                'params': model.module.fc.parameters(),
                'lr': learning_rate
            }, {
                'params': model.module.fc2.parameters(),
                'lr': learning_rate
            }],
                                   lr=learning_rate / 10)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy_1 = 0.0
    best_val_accuracy_2 = 0.0
    correspond_train_acc_1 = 0.0
    correspond_train_acc_2 = 0.0

    # Per-epoch record: 2 train accuracies + aux, 3 train losses, the valid
    # counterparts — 12 values per epoch.
    record_np = np.zeros([epochs, 12])

    # BUG FIX: public_name was only assigned inside the "good checkpoint"
    # branch below, so saving the record at the end raised NameError when no
    # epoch ever crossed the 0.885 threshold.
    public_name = None

    for epoch in range(epochs):
        # np.random.seed(epoch)
        # Reshuffle clip order (not frames within a clip) every epoch.
        np.random.shuffle(train_we_use_start_idx)
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_we_use_start_idx[i] + j)

        train_loader = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  sampler=train_idx,
                                  num_workers=workers,
                                  pin_memory=False)

        model.train()
        train_loss_1 = 0.0
        train_loss_2 = 0.0
        train_loss_3 = 0.0
        train_corrects_1 = 0
        train_corrects_2 = 0
        train_corrects_3 = 0

        train_start_time = time.time()
        for data in train_loader:
            inputs, labels_1, labels_2 = data
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels_1 = Variable(labels_1.cuda())
                labels_2 = Variable(labels_2.cuda())
            else:
                inputs = Variable(inputs)
                labels_1 = Variable(labels_1)
                labels_2 = Variable(labels_2)

            optimizer.zero_grad()

            outputs_1, outputs_2, outputs_3 = model.forward(inputs)

            # Head 2: single-label phase classification.
            _, preds_2 = torch.max(outputs_2.data, 1)
            train_corrects_2 += torch.sum(preds_2 == labels_2.data)

            sig_output_1 = sigmoid_cuda(outputs_1)
            sig_output_3 = sigmoid_cuda(outputs_3)

            # Heads 1/3: multi-label (7 tools); also track their average.
            sig_average = (sig_output_1.data + sig_output_3.data) / 2

            preds_1 = torch.cuda.ByteTensor(sig_output_1.data > 0.5)
            preds_1 = preds_1.long()
            train_corrects_1 += torch.sum(preds_1 == labels_1.data)

            preds_3 = torch.cuda.ByteTensor(sig_average > 0.5)
            preds_3 = preds_3.long()
            train_corrects_3 += torch.sum(preds_3 == labels_1.data)

            labels_1 = Variable(labels_1.data.float())
            loss_1 = criterion_1(outputs_1, labels_1)
            loss_2 = criterion_2(outputs_2, labels_2)

            # KL consistency term; head 3's output is detached so only head 1
            # is pulled toward it.
            sig_output_3 = Variable(sig_output_3.data, requires_grad=False)
            loss_3 = torch.abs(criterion_3(sig_output_1, sig_output_3))
            loss = loss_1 + loss_2 + loss_3 * alpha
            loss.backward()
            optimizer.step()

            train_loss_1 += loss_1.data[0]
            train_loss_2 += loss_2.data[0]
            train_loss_3 += loss_3.data[0]

        train_elapsed_time = time.time() - train_start_time
        # The /7 normalizes over the 7 per-frame tool labels.
        train_accuracy_1 = train_corrects_1 / num_train_all / 7
        train_accuracy_2 = train_corrects_2 / num_train_all
        train_accuracy_3 = train_corrects_3 / num_train_all / 7
        train_average_loss_1 = train_loss_1 / num_train_all / 7
        train_average_loss_2 = train_loss_2 / num_train_all
        train_average_loss_3 = train_loss_3 / num_train_all

        # begin eval

        model.eval()
        val_loss_1 = 0.0
        val_loss_2 = 0.0
        val_loss_3 = 0.0
        val_corrects_1 = 0
        val_corrects_2 = 0
        val_corrects_3 = 0

        val_start_time = time.time()
        for data in val_loader:
            inputs, labels_1, labels_2 = data
            # Phase head is only scored on the last frame of each clip.
            labels_2 = labels_2[(sequence_length - 1)::sequence_length]
            if use_gpu:
                inputs = Variable(inputs.cuda(), volatile=True)
                labels_1 = Variable(labels_1.cuda(), volatile=True)
                labels_2 = Variable(labels_2.cuda(), volatile=True)
            else:
                inputs = Variable(inputs, volatile=True)
                labels_1 = Variable(labels_1, volatile=True)
                labels_2 = Variable(labels_2, volatile=True)

            outputs_1, outputs_2, outputs_3 = model.forward(inputs)
            outputs_2 = outputs_2[(sequence_length - 1)::sequence_length]
            _, preds_2 = torch.max(outputs_2.data, 1)
            val_corrects_2 += torch.sum(preds_2 == labels_2.data)

            sig_output_1 = sigmoid_cuda(outputs_1)
            sig_output_3 = sigmoid_cuda(outputs_3)

            sig_average = (sig_output_1.data + sig_output_3.data) / 2

            preds_1 = torch.cuda.ByteTensor(sig_output_1.data > 0.5)
            preds_1 = preds_1.long()
            val_corrects_1 += torch.sum(preds_1 == labels_1.data)

            preds_3 = torch.cuda.ByteTensor(sig_average > 0.5)
            preds_3 = preds_3.long()
            val_corrects_3 += torch.sum(preds_3 == labels_1.data)

            labels_1 = Variable(labels_1.data.float())
            loss_1 = criterion_1(outputs_1, labels_1)
            loss_2 = criterion_2(outputs_2, labels_2)

            sig_output_3 = Variable(sig_output_3.data, requires_grad=False)
            loss_3 = torch.abs(criterion_3(sig_output_1, sig_output_3))

            val_loss_1 += loss_1.data[0]
            val_loss_2 += loss_2.data[0]
            val_loss_3 += loss_3.data[0]

        val_elapsed_time = time.time() - val_start_time
        val_accuracy_1 = val_corrects_1 / (num_val_all * 7)
        val_accuracy_2 = val_corrects_2 / num_val_we_use
        val_accuracy_3 = val_corrects_3 / (num_val_all * 7)
        val_average_loss_1 = val_loss_1 / (num_val_all * 7)
        val_average_loss_2 = val_loss_2 / num_val_we_use
        val_average_loss_3 = val_loss_3 / num_val_all

        print('epoch: {:3d}'
              ' train time: {:2.0f}m{:2.0f}s'
              ' train accu_1: {:.4f}'
              ' train accu_3: {:.4f}'
              ' train accu_2: {:.4f}'
              ' train loss_1: {:4.4f}'
              ' train loss_2: {:4.4f}'
              ' train loss_3: {:4.4f}'.format(
                  epoch, train_elapsed_time // 60, train_elapsed_time % 60,
                  train_accuracy_1, train_accuracy_3, train_accuracy_2,
                  train_average_loss_1, train_average_loss_2,
                  train_average_loss_3))
        print('epoch: {:3d}'
              ' valid time: {:2.0f}m{:2.0f}s'
              ' valid accu_1: {:.4f}'
              ' valid accu_3: {:.4f}'
              ' valid accu_2: {:.4f}'
              ' valid loss_1: {:4.4f}'
              ' valid loss_2: {:4.4f}'
              ' valid loss_3: {:4.4f}'.format(
                  epoch, val_elapsed_time // 60, val_elapsed_time % 60,
                  val_accuracy_1, val_accuracy_3, val_accuracy_2,
                  val_average_loss_1, val_average_loss_2, val_average_loss_3))

        if optimizer_choice == 0:
            if sgd_adjust_lr == 0:
                exp_lr_scheduler.step()
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler.step(val_average_loss_1 + val_average_loss_2 +
                                      alpha * val_average_loss_3)

        # Model selection: primarily by phase accuracy (val_2), then by tool
        # accuracy (val_1), then by the corresponding train accuracies.
        if val_accuracy_2 > best_val_accuracy_2 and val_accuracy_1 > 0.95:
            best_val_accuracy_2 = val_accuracy_2
            best_val_accuracy_1 = val_accuracy_1
            correspond_train_acc_1 = train_accuracy_1
            correspond_train_acc_2 = train_accuracy_2
            best_model_wts = copy.deepcopy(model.state_dict())
        elif val_accuracy_2 == best_val_accuracy_2 and val_accuracy_1 > 0.95:
            if val_accuracy_1 > best_val_accuracy_1:
                # BUG FIX: best_val_accuracy_1 was never updated in this
                # branch, so later epochs compared against a stale value.
                best_val_accuracy_1 = val_accuracy_1
                correspond_train_acc_1 = train_accuracy_1
                correspond_train_acc_2 = train_accuracy_2
                best_model_wts = copy.deepcopy(model.state_dict())
            elif val_accuracy_1 == best_val_accuracy_1:
                if train_accuracy_2 > correspond_train_acc_2:
                    correspond_train_acc_2 = train_accuracy_2
                    correspond_train_acc_1 = train_accuracy_1
                    best_model_wts = copy.deepcopy(model.state_dict())
                elif train_accuracy_2 == correspond_train_acc_2:
                    # BUG FIX: the original compared a train accuracy against
                    # best_val_accuracy_1 (a validation metric); the tie-break
                    # is on the corresponding train accuracy.
                    if train_accuracy_1 > correspond_train_acc_1:
                        correspond_train_acc_1 = train_accuracy_1
                        best_model_wts = copy.deepcopy(model.state_dict())

        if val_accuracy_2 > 0.885:
            save_val_1 = int("{:4.0f}".format(val_accuracy_1 * 10000))
            save_val_2 = int("{:4.0f}".format(val_accuracy_2 * 10000))
            save_train_1 = int("{:4.0f}".format(train_accuracy_1 * 10000))
            save_train_2 = int("{:4.0f}".format(train_accuracy_2 * 10000))
            public_name = "cnn_lstm_p2t" \
                          + "_epoch_" + str(epochs) \
                          + "_length_" + str(sequence_length) \
                          + "_opt_" + str(optimizer_choice) \
                          + "_mulopt_" + str(multi_optim) \
                          + "_flip_" + str(use_flip) \
                          + "_crop_" + str(crop_type) \
                          + "_batch_" + str(train_batch_size) \
                          + "_train1_" + str(save_train_1) \
                          + "_train2_" + str(save_train_2) \
                          + "_val1_" + str(save_val_1) \
                          + "_val2_" + str(save_val_2)
            model_name = public_name + ".pth"
            torch.save(best_model_wts, model_name)

        record_np[epoch, 0] = train_accuracy_1
        record_np[epoch, 1] = train_accuracy_3
        record_np[epoch, 2] = train_accuracy_2
        record_np[epoch, 3] = train_average_loss_1
        record_np[epoch, 4] = train_average_loss_2
        record_np[epoch, 5] = train_average_loss_3

        record_np[epoch, 6] = val_accuracy_1
        record_np[epoch, 7] = val_accuracy_3
        # BUG FIX: val_accuracy_2 was written to column 7, clobbering
        # val_accuracy_3 and leaving column 8 permanently zero.
        record_np[epoch, 8] = val_accuracy_2
        record_np[epoch, 9] = val_average_loss_1
        record_np[epoch, 10] = val_average_loss_2
        record_np[epoch, 11] = val_average_loss_3

    print('best accuracy_1: {:.4f} cor train accu_1: {:.4f}'.format(
        best_val_accuracy_1, correspond_train_acc_1))
    print('best accuracy_2: {:.4f} cor train accu_2: {:.4f}'.format(
        best_val_accuracy_2, correspond_train_acc_2))

    if public_name is None:
        # No checkpoint crossed the save threshold: build the record name
        # from the best statistics instead of crashing on an unbound name.
        save_val_1 = int("{:4.0f}".format(best_val_accuracy_1 * 10000))
        save_val_2 = int("{:4.0f}".format(best_val_accuracy_2 * 10000))
        save_train_1 = int("{:4.0f}".format(correspond_train_acc_1 * 10000))
        save_train_2 = int("{:4.0f}".format(correspond_train_acc_2 * 10000))
        public_name = "cnn_lstm_p2t" \
                      + "_epoch_" + str(epochs) \
                      + "_length_" + str(sequence_length) \
                      + "_opt_" + str(optimizer_choice) \
                      + "_mulopt_" + str(multi_optim) \
                      + "_flip_" + str(use_flip) \
                      + "_crop_" + str(crop_type) \
                      + "_batch_" + str(train_batch_size) \
                      + "_train1_" + str(save_train_1) \
                      + "_train2_" + str(save_train_2) \
                      + "_val1_" + str(save_val_1) \
                      + "_val2_" + str(save_val_2)

    record_name = public_name + ".npy"
    np.save(record_name, record_np)
示例#17
0
def main(parser, logger):
    """Train a ResNet-18 embedding with an ArcFace margin head on Omniglot,
    validate episodically, keep the best weights, and report test accuracy.

    Args:
        parser: parsed command-line options (dataset_root, epochs,
            learning_rate, lr scheduler settings, episode sizes, ...).
        logger: accepted for interface compatibility; not used here.
    """
    print('--> Preparing Dataset:')
    trainset = OmniglotDataset(mode='train', root=parser.dataset_root)
    trainloader = data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=0)
    valset = dataloader(parser, 'val')
    testset = dataloader(parser, 'test')
    print('--> Building Model:')
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = Network.resnet18().to(device)
    model = DataParallel(model)
    metric = ArcMarginProduct(256, len(np.unique(trainset.y)), s=30, m=0.5).to(device)
    metric = DataParallel(metric)
    criterion = torch.nn.CrossEntropyLoss()
    print('--> Initializing Optimizer and Scheduler:')
    # BUG FIX: after the DataParallel wrap the margin's parameters live under
    # `metric.module`; `metric.weight` raised AttributeError here.
    # NOTE: the per-group weight_decay=5e-4 overrides the global 0.0005
    # (they happen to be the same value).
    optimizer = torch.optim.Adam(
        [{'params': model.parameters(), 'weight_decay': 5e-4},
         {'params': metric.module.parameters(), 'weight_decay': 5e-4}],
        lr=parser.learning_rate, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                gamma=parser.lr_scheduler_gamma,
                                                step_size=parser.lr_scheduler_step)
    best_acc = 0
    # BUG FIX: state_dict() returns live references to the training tensors,
    # so a plain assignment would silently track the *latest* weights instead
    # of freezing the best ones.  Clone to take a real snapshot.
    best_state = {k: v.clone() for k, v in model.state_dict().items()}
    for epoch in range(parser.epochs):
        print('\nEpoch: %d' % epoch)
        # Training
        train_loss = 0
        train_acc = 0
        train_correct = 0
        train_total = 0
        model.train()
        for batch_index, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device).long()
            feature = model(inputs)
            # ArcMargin needs the labels to apply the angular margin.
            output = metric(feature, targets)
            loss = criterion(output, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = output.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
        scheduler.step()
        train_acc = 100.*train_correct / train_total
        print('Training Loss: {} | Accuracy: {}'.format(train_loss/train_total, train_acc))
        # Validating — episodic evaluation via the project-level `eval`
        # helper (which shadows the builtin); no gradients needed.
        val_correct = 0
        val_total = 0
        model.eval()
        with torch.no_grad():
            for batch_index, (inputs, targets) in enumerate(valset):
                inputs = inputs.to(device)
                targets = targets.to(device)
                feature = model(inputs)
                correct = eval(input=feature, target=targets, n_support=parser.num_support_val)
                val_correct += correct
                val_total += parser.classes_per_it_val * parser.num_query_val
        val_acc = 100.*val_correct / val_total
        print('Validating Accuracy: {}'.format(val_acc))
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
    # Testing: reload the best validation weights and average over 10 passes
    # of the episodic test loader.
    test_correct = 0
    test_total = 0
    model.load_state_dict(best_state)
    with torch.no_grad():
        for epoch in range(10):
            for batch_index, (inputs, targets) in enumerate(testset):
                inputs = inputs.to(device)
                targets = targets.to(device)
                feature = model(inputs)
                correct = eval(input=feature, target=targets, n_support=parser.num_support_val)
                test_correct += correct
                test_total += parser.classes_per_it_val * parser.num_query_val
    test_acc = 100. * test_correct / test_total
    print('Testing Accuracy: {}'.format(test_acc))
示例#18
0
class SMSG_GAN:
    """ Unconditional SMSG_GAN

        args:
            depth: depth of the GAN (will be used for each generator and discriminator)
            latent_size: latent size of the manifold used by the GAN
            device: device to run the GAN on (GPU / CPU)
    """

    def __init__(self, depth=7, latent_size=512, device=th.device("cpu")):
        """ constructor for the class """
        from torch.nn import DataParallel

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size).to(device)
        self.dis = Discriminator(depth, latent_size).to(device)

        if device == th.device("cuda"):  # apply the data parallel if device is GPU
            # BUG FIX: the original constructed brand-new CPU modules inside
            # DataParallel, discarding both the `.to(device)` placement and
            # the networks built above (fresh random weights on the CPU).
            # Wrap the already-moved instances instead.
            self.gen = DataParallel(self.gen)
            self.dis = DataParallel(self.dis)

        # state of the object
        self.latent_size = latent_size
        self.depth = depth
        self.device = device

        # by default the generator and discriminator are in eval mode
        self.gen.eval()
        self.dis.eval()

    def _unwrap(self, net):
        """ return the underlying module of a (possibly) DataParallel net """
        # plain nn.Modules have no `.module` attribute; DataParallel does
        return net.module if hasattr(net, "module") else net

    def generate_samples(self, num_samples):
        """
        generate samples using this gan
        :param num_samples: number of samples to be generated
        :return: generated samples tensors: list[ Tensor(B x H x W x C)]
                 generated attn-map tensors: list[ Tensor(B x H x W x H x W)]
        """
        noise = th.randn(num_samples, self.latent_size).to(self.device)
        generated_images, attention_maps = self.gen(noise)

        # reshape the generated images from [-1, 1] CHW to [0, 1] HWC
        generated_images = list(map(lambda x: (x.detach().permute(0, 2, 3, 1) / 2) + 0.5,
                                    generated_images))

        attention_maps = list(map(lambda x: x.detach(), attention_maps))

        return generated_images, attention_maps

    def optimize_discriminator(self, dis_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on discriminator using the batch of data
        :param dis_optim: discriminator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples (detached so no grads flow to the gen)
        fake_samples, _ = self.gen(noise)
        fake_samples = list(map(lambda x: x.detach(), fake_samples))

        loss = loss_fn.dis_loss(real_batch, fake_samples)

        # optimize discriminator
        dis_optim.zero_grad()
        loss.backward()
        dis_optim.step()

        return loss.item()

    def optimize_generator(self, gen_optim, noise, real_batch, loss_fn):
        """
        performs one step of weight update on generator using the batch of data
        :param gen_optim: generator optimizer
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
                           should contain a list of tensors at different scales
        :param loss_fn: loss function to be used (object of GANLoss)
        :return: current loss
        """

        # generate a batch of samples
        fake_samples, _ = self.gen(noise)

        loss = loss_fn.gen_loss(real_batch, fake_samples)

        # optimize generator
        gen_optim.zero_grad()
        loss.backward()
        gen_optim.step()

        return loss.item()

    @staticmethod
    def create_grid(samples, img_files):
        """
        utility function to create a grid of GAN samples
        :param samples: generated samples for storing list[Tensors]
        :param img_files: list of names of files to write
        :return: None (saves multiple files)
        """
        from torchvision.utils import save_image
        from numpy import sqrt

        samples = list(map(lambda x: th.clamp((x.detach() / 2) + 0.5, min=0, max=1),
                           samples))

        # save the images:
        for sample, img_file in zip(samples, img_files):
            save_image(sample, img_file, nrow=int(sqrt(sample.shape[0])))

    def train(self, data, gen_optim, dis_optim, loss_fn,
              start=1, num_epochs=12, feedback_factor=10, checkpoint_factor=1,
              data_percentage=100, num_samples=64,
              log_dir=None, sample_dir="./samples",
              save_dir="./models"):
        """
        method to train the SMSG-GAN network
        :param data: object of pytorch dataloader which provides iterator to the data
        :param gen_optim: optimizer for the generator parameters
        :param dis_optim: optimizer for discriminator parameters
        :param loss_fn: object of GANLoss (defines the loss function)
        :param start: starting epoch number
        :param num_epochs: ending epoch number
        :param feedback_factor: number of samples (logs) generated per epoch
        :param checkpoint_factor: model saved after these many epochs
        :param data_percentage: amount of data to be used for training
        :param num_samples: number of samples in the generated sample grid
                            (preferably a perfect square number)
        :param log_dir: path to the directory for saving the loss.log file
        :param sample_dir: path to the directory for saving the generated samples
        :param save_dir: path to the directory for saving trained models
        :return: None (saves model on disk)
        """

        from torch.nn.functional import avg_pool2d

        # turn the generator and discriminator into train mode
        self.gen.train()
        self.dis.train()

        assert isinstance(gen_optim, th.optim.Optimizer), \
            "gen_optim is not an Optimizer"
        assert isinstance(dis_optim, th.optim.Optimizer), \
            "dis_optim is not an Optimizer"

        print("Starting the training process ... ")

        # create fixed_input for debugging
        fixed_input = th.randn(num_samples, self.latent_size).to(self.device)

        # create a global time counter
        global_time = time.time()

        for epoch in range(start, num_epochs + 1):
            # record time at the start of epoch (under its own name so the
            # `start` parameter is not shadowed)
            epoch_start = timeit.default_timer()

            print("\nEpoch: %d" % epoch)
            # DataLoaders expose __len__ directly; no need to build an iterator
            total_batches = len(data)

            limit = int((data_percentage / 100) * total_batches)
            # BUG FIX: int(limit / feedback_factor) can be 0 for small
            # datasets/limits, which made the modulo below raise
            # ZeroDivisionError.  Clamp the logging interval to >= 1.
            log_interval = max(int(limit / feedback_factor), 1)

            for (i, batch) in enumerate(data, 1):

                # extract current batch of data for training
                images = batch.to(self.device)
                extracted_batch_size = images.shape[0]

                # create a list of downsampled images from the real images:
                images = [images] + [avg_pool2d(images, int(np.power(2, i)))
                                     for i in range(1, self.depth)]
                images = list(reversed(images))

                gan_input = th.randn(
                    extracted_batch_size, self.latent_size).to(self.device)

                # optimize the discriminator:
                dis_loss = self.optimize_discriminator(dis_optim, gan_input,
                                                       images, loss_fn)

                # optimize the generator:
                # resample from the latent noise
                gan_input = th.randn(
                    extracted_batch_size, self.latent_size).to(self.device)
                gen_loss = self.optimize_generator(gen_optim, gan_input,
                                                   images, loss_fn)

                # provide a loss feedback
                if i % log_interval == 0 or i == 1:
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("\nElapsed [%s] batch: %d  d_loss: %f  g_loss: %f"
                          % (elapsed, i, dis_loss, gen_loss))
                    # BUG FIX: `.module` only exists under DataParallel; on
                    # the CPU path the attribute access crashed.
                    print("Generator_gammas:", self._unwrap(self.gen).get_gammas())
                    print("Discriminator_gammas:", self._unwrap(self.dis).get_gammas())

                    # also write the losses to the log file:
                    if log_dir is not None:
                        log_file = os.path.join(log_dir, "loss.log")
                        os.makedirs(os.path.dirname(log_file), exist_ok=True)
                        with open(log_file, "a") as log:
                            log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n")

                    # create a grid of samples and save it
                    reses = [str(int(np.power(2, dep))) + "_x_"
                             + str(int(np.power(2, dep)))
                             for dep in range(2, self.depth + 2)]
                    gen_img_files = [os.path.join(sample_dir, res, "gen_" +
                                                  str(epoch) + "_" +
                                                  str(i) + ".png")
                                     for res in reses]

                    # Make sure all the required directories exist
                    # otherwise make them
                    os.makedirs(sample_dir, exist_ok=True)
                    for gen_img_file in gen_img_files:
                        os.makedirs(os.path.dirname(gen_img_file), exist_ok=True)

                    self.create_grid(self.gen(fixed_input)[0], gen_img_files)

                if i > limit:
                    break

            # calculate the time required for the epoch
            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - epoch_start))

            if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == num_epochs:
                os.makedirs(save_dir, exist_ok=True)
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(epoch) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(epoch) + ".pth")

                th.save(self.gen.state_dict(), gen_save_file)
                th.save(self.dis.state_dict(), dis_save_file)

        print("Training completed ...")

        # return the generator and discriminator back to eval mode
        self.gen.eval()
        self.dis.eval()
示例#19
0
class Solver():
    """Training / validation harness for multi-slice MRI segmentation.

    Wires together the train/test data loaders, backbone network,
    optimizer, LR scheduler and combined loss, optionally resumes from a
    checkpoint, and exposes per-epoch ``train`` and ``validation`` loops.
    """

    def __init__(self, args, num_class):
        """Build the data pipeline, model and optimizer; optionally resume.

        Args:
            args: parsed command-line namespace (data paths, lr, cuda
                flags, checkpoint options, ...).
            num_class: number of segmentation labels the model predicts.
        """
        self.args = args
        self.num_class = num_class

        #load data
        train_data = LoadMRIData(args.data_dir, args.data_list, 'train', num_class, num_slices=args.num_slices, se_loss = args.se_loss, use_weight = args.use_weights, Encode3D=args.encode3D)
        self.train_loader = DataLoader(train_data, batch_size = args.batch_size, shuffle = True, num_workers = args.workers, pin_memory=True)
        
        test_data = LoadMRIData(args.data_dir, args.data_list, 'test', num_class, num_slices=args.num_slices, se_loss = False, Encode3D=args.encode3D)
        self.test_loader = DataLoader(test_data, batch_size = 1, shuffle = False, num_workers = args.workers, pin_memory=True)
        
        model = Backbone(num_class, args.num_slices)
        
        ####################################################################################
        #set optimizer for different training strategies
        if args.two_stages:
            # Stage-two fine-tuning: pretrained layers get a tiny lr (1e-7)
            # while the newly added heads (decode0/conv7/conv8) train at args.lr.
            optimizer = torch.optim.SGD([{'params': model.encode3D1.parameters()},
                                     {'params': model.encode3D2.parameters()},
                                     {'params': model.encode3D3.parameters()},
                                     {'params': model.encode3D4.parameters()},
                                     {'params': model.bottleneck3D.parameters()},
                                     {'params': model.decode4.parameters()},
                                     {'params': model.decode3.parameters()},
                                     {'params': model.decode2.parameters()},
                                     {'params': model.decode1.parameters()},
                                     {'params': model.encmodule.parameters()},
                                     {'params': model.conv6.parameters()},
                                     #new added parameters
                                     {'params': model.decode0.parameters(), 'lr': args.lr},
                                     {'params': model.conv7.parameters(), 'lr': args.lr},
                                     {'params': model.conv8.parameters(), 'lr': args.lr},
                                     ],
                                    lr=1e-7, momentum=0.9, weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                        momentum=0.9, weight_decay=args.weight_decay)
        ####################################################################################
                
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))
        #define loss
        self.criterion = lm.CombinedLoss(se_loss = args.se_loss)
        
        self.model, self.optimizer = model, optimizer
        
        # Using cuda
        if args.cuda:
            self.model = DataParallel(self.model).cuda()
            self.criterion = self.criterion.cuda()
        
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            
            if args.resume_pretrain:
                # Full resume: restore epoch counter, optimizer state and
                # best score in addition to the weights below.
                args.start_epoch = checkpoint['epoch']
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.best_pred = checkpoint['best_pred']
            
            model_dict = model.state_dict()
            # 1. filter out unnecessary keys (allows loading partially
            #    matching checkpoints, e.g. from a different head)
            pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
            
            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            
            # When wrapped in DataParallel the real module lives at .module
            if args.cuda:
                self.model.module.load_state_dict(model_dict)
            else:
                self.model.load_state_dict(model_dict)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))


    def train(self, epoch):
        """Run one training epoch and (optionally) save a checkpoint.

        Args:
            epoch: zero-based epoch index, used by the LR scheduler and
                the checkpoint filename.
        """
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        for i, sample_batched in enumerate(tbar):
            
            target = sample_batched['label'].type(torch.LongTensor)
            skull = sample_batched['skull'].type(torch.LongTensor)
            
            if self.args.cuda:
                target, skull = target.cuda(), skull.cuda()
            
            
            image_3D = sample_batched['image_stack'].type(torch.FloatTensor)
            image_3D = image_3D.cuda()
            
            # Optional semantic-encoding auxiliary target
            se_gt = None
            if self.args.se_loss:
                se_gt = sample_batched['se_gt'].type(torch.FloatTensor)
                if self.args.cuda:
                    se_gt = se_gt.cuda()
            
            # Optional per-class weights for the loss
            weights = None
            if self.args.use_weights:
                weights = sample_batched['weights'].type(torch.FloatTensor)
                if self.args.cuda:
                    weights = weights.cuda()

            # Scheduler adjusts the lr per-iteration (poly/step/cos)
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            
            outputs = self.model(image_3D)
            loss = self.criterion(outputs, target, skull, se_gt, weight = weights)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            
            del target, skull, image_3D
            
        print("==== Epoch [" + str(epoch) + " / " + str(self.args.epochs) + "] DONE ====")
        print('Loss: %.3f' % train_loss)
        
        if self.args.no_val:
            # save checkpoint every epoch (no validation-based selection)
            torch.save({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, os.path.join(self.args.save_dir, 'checkpoint_%03d.pth.tar' % epoch))
            print("save model on epoch %d" % epoch)
            
        
    def validation(self, epoch):
        """Evaluate the whole test set volume-by-volume.

        Computes per-class Dice and IoU for each volume, prints the
        aggregate statistics, and saves a checkpoint whenever the mean
        Dice improves on ``self.best_pred``.
        """
        self.model.eval()
        tbar = tqdm(self.test_loader, desc='\r')
        volume_dice_score_list = []
        volume_iou_score_list = []
        batch_size = self.args.test_batch_size
        
        with torch.no_grad():
            for ind, sample_batched in enumerate(tbar):
                volume = sample_batched['image_3D'].type(torch.FloatTensor)
                labelmap = sample_batched['label_3D'].type(torch.LongTensor)
                volume = torch.squeeze(volume)
                labelmap = torch.squeeze(labelmap)
                sample_name = sample_batched['name']
                
                if self.args.cuda:
                    volume, labelmap = volume.cuda(), labelmap.cuda()
                
                z_ax, x_ax, y_ax = np.shape(volume)
                
                volume_prediction = []
                skull_prediction = []
                for i in range(0, len(volume), batch_size):

                    # Build two overlapping slice stacks around index i,
                    # clamped at the volume boundaries so each stack always
                    # has num_slices*2+1 slices.
                    if i<=int(self.args.num_slices*2+1):
                        image_stack0 = volume[0:int(self.args.num_slices*2+1),:,:][None]
                        image_stack1 = volume[1:int(self.args.num_slices*2+2),:,:][None]
                    elif i >=z_ax-int(self.args.num_slices*2+1):
                        image_stack0 = volume[z_ax-int(self.args.num_slices*2+2):-1,:,:][None]
                        image_stack1 = volume[z_ax-int(self.args.num_slices*2+1):,:,:][None]
                    else:
                        image_stack0 = volume[i-self.args.num_slices:i+self.args.num_slices+1,:,:][None]
                        image_stack1 = volume[i-self.args.num_slices+1:i+self.args.num_slices+2,:,:][None]
                    
                    image_3D = torch.cat((image_stack0, image_stack1), dim =0)
                    
                    outputs = self.model(image_3D)
                    pred = outputs[0]
                    skull_pred = outputs[2]
                    
                    _, batch_output = torch.max(pred, dim=1)
                    _, skull_output = torch.max(skull_pred, dim=1)
                    volume_prediction.append(batch_output)
                    skull_prediction.append(skull_output)
                
                #volume and label are both CxHxW
                volume_prediction = torch.cat(volume_prediction)
                
                #dice and iou evaluation
                volume_dice_score, volume_iou_score= score_perclass(volume_prediction, labelmap, self.num_class)
                
                volume_dice_score = volume_dice_score.cpu().numpy()
                volume_dice_score_list.append(volume_dice_score)
                tbar.set_description('Validate Dice Score: %.3f' % (np.mean(volume_dice_score)))
                
                volume_iou_score = volume_iou_score.cpu().numpy()
                volume_iou_score_list.append(volume_iou_score)

                ####################save output for visualization##################################
                visual = False
                if visual:
                    savedir_pred = os.path.join(self.args.save_dir,'pred')
                    if not os.path.exists(savedir_pred):
                        os.makedirs(savedir_pred)
                    volume_prediction = volume_prediction.cpu().numpy().astype(np.uint8)
                    volume_prediction = np.transpose(volume_prediction, (1,2,0))
                    nib_pred = nib.Nifti1Image(volume_prediction, affine=np.eye(4))
                    nib.save(nib_pred, os.path.join(savedir_pred, sample_name[0]+'.nii.gz'))
                ####################save output for visualization##################################
                
            
            del volume_prediction
            
            dice_score_arr = np.asarray(volume_dice_score_list)
            iou_score_arr = np.asarray(volume_iou_score_list)
            
            ####################################save best model for dice###################
            # NOTE: was `is 139` (identity comparison with an int literal,
            # which is implementation-dependent); must be equality.
            if self.args.num_class == 139:
                label_list = np.array([4,  11,  23,  30,  31,  32,  35,  36,  37,  38,  39,  40,
                           41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  55,
                           56,  57,  58,  59,  60,  61,  62,  63,  64,  69,  71,  72,  73,
                           75,  76, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 112,
                          113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
                          128, 129, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
                          143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
                          156, 157, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170,
                          171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
                          184, 185, 186, 187, 190, 191, 192, 193, 194, 195, 196, 197, 198,
                          199, 200, 201, 202, 203, 204, 205, 206, 207])
                total_idx = np.arange(0, len(label_list))
                # labels excluded from model selection
                ignore = np.array([42, 43, 64, 69])
                
                # +1 shifts past the background column (index 0), kept below
                valid_idx = [i+1 for i in total_idx if label_list[i] not in ignore]
                valid_idx = [0] + valid_idx
                
                dice_score_vali = dice_score_arr[:,valid_idx]
                iou_score_vali = iou_score_arr[:,valid_idx]
            else:
                dice_score_vali = dice_score_arr
                iou_score_vali = iou_score_arr
            ####################################save best model for dice###################
            
            avg_dice_score = np.mean(dice_score_vali)
            std_dice_score = np.std(dice_score_vali)
            avg_iou_score = np.mean(iou_score_vali)
            std_iou_score = np.std(iou_score_vali)
            print('Validation:')
            print("Mean of dice score : " + str(avg_dice_score))
            print("Std of dice score : " + str(std_dice_score))
            print("Mean of iou score : " + str(avg_iou_score))
            # was mislabelled "Std of dice score"
            print("Std of iou score : " + str(std_iou_score))
            
            if avg_dice_score>self.best_pred:
                # os.path.join instead of string concatenation so the files
                # land inside save_dir even without a trailing slash
                np.save(os.path.join(self.args.save_dir, 'dice_score.npy'), dice_score_arr)
                np.save(os.path.join(self.args.save_dir, 'iou_score.npy'), iou_score_arr)
                self.best_pred = avg_dice_score
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': self.model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, os.path.join(self.args.save_dir, 'checkpoint_%03d.pth.tar' % epoch))
                print("save model on epoch %d" % epoch)
示例#20
0
class Trainer(object):
    """Pose-regression training loop with checkpointing and TensorBoard logging.

    Ported to Python 3: ``xrange`` -> ``range``, ``print`` statement ->
    function, ``dict.has_key`` -> ``in``, ``.cuda(async=True)`` ->
    ``.cuda(non_blocking=True)`` (``async`` is a reserved keyword since
    Python 3.7), and ``loss.data[0]`` -> ``loss.item()``.
    """

    def __init__(self,
                 model,
                 optimizer,
                 configuration,
                 train_criterion,
                 train_dataloader,
                 val_dataloader,
                 val_criterion=None,
                 result_criterion=None,
                 **kwargs):
        """Set up model (optionally data-parallel), criteria and logging.

        Args:
            model: network to train.
            optimizer: optimizer already bound to ``model``'s parameters.
            configuration: config object with attribute/key access
                (log dirs, seed, cuda flags, checkpoint options, ...).
            train_criterion: loss used for training steps.
            train_dataloader / val_dataloader: data iterators.
            val_criterion: loss for validation; defaults to the train loss.
            result_criterion: metric producing (t_loss, q_loss); defaults
                to the validation criterion.
        """
        self.config = configuration

        # Wrap in DataParallel only when more than one GPU is visible.
        if torch.cuda.device_count() == 1:
            self.model = model
        else:
            print("Parallel data processing...")
            self.model = DataParallel(model)
        self.train_criterion = train_criterion

        self.best_model = None
        self.best_model_filename = osp.join(self.config.log_output_dir,
                                            self.config.best_model_name)

        if val_criterion is None:
            self.val_criterion = train_criterion
        else:
            self.val_criterion = val_criterion
        if result_criterion is None:
            print("result_criterion is None")
            self.result_criterion = self.val_criterion
        else:
            self.result_criterion = result_criterion

        self.optimizer = optimizer

        if self.config.tf:
            self.writer = SummaryWriter(log_dir=self.config.tf_dir)
            self.loss_win = 'loss_win'
            self.result_win = 'result_win'
            self.criterion_params_win = 'cparam_win'
            # Learnable loss weights (if any) are plotted separately.
            criterion_params = {
                k: v.data.cpu().numpy()[0]
                for k, v in self.train_criterion.named_parameters()
            }
            self.n_criterion_params = len(criterion_params)
        # set random seed for reproducibility
        torch.manual_seed(self.config.seed)
        if self.config.cuda:
            torch.cuda.manual_seed(self.config.seed)

        # initiate model with checkpoint
        self.start_epoch = 1
        if self.config["checkpoint"]:
            self.load_checkpoint()
        else:
            print("No checkpoint file")
        print('start_epoch = {}'.format(self.start_epoch))

        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

        # Pose normalization statistics (mean, std) used to un-normalize
        # predictions when computing the result metric.
        self.pose_m, self.pose_s = np.loadtxt(self.config.pose_stats_file)
        self.pose_m = Variable(torch.from_numpy(self.pose_m).float(),
                               requires_grad=False).cuda(non_blocking=True)
        self.pose_s = Variable(torch.from_numpy(self.pose_s).float(),
                               requires_grad=False).cuda(non_blocking=True)

        if self.config.cuda:
            self.model.cuda()
            self.train_criterion.cuda()
            self.val_criterion.cuda()

    def run(self):
        """Main loop: validate / snapshot / train for each epoch."""
        n_epochs = self.config.n_epochs
        for epoch in range(self.start_epoch, n_epochs + 1):
            # validate
            val_loss = None
            if self.config.do_val and ((epoch % self.config.val_freq == 0) or
                                       (epoch == n_epochs - 1)):
                val_loss = self.validate(epoch)
                # Keep the checkpoint with the lowest validation loss.
                if self.best_model is None or self.best_model[
                        'loss'] is None or val_loss < self.best_model['loss']:
                    self.best_model = self.pack_checkpoint(epoch, val_loss)
                    torch.save(
                        self.best_model,
                        osp.join(self.config.checkpoint_dir,
                                 self.config.best_model_name))
                    print("Best model saved at epoch {:d} in name of {:s}".
                          format(epoch, self.config.best_model_name))

            # save periodic checkpoint
            if epoch % self.config.snapshot == 0:
                checkpoint = self.pack_checkpoint(epoch=epoch, loss=val_loss)
                fn = osp.join(self.config.checkpoint_dir,
                              "epoch_{:04d}.pth.tar".format(epoch))
                torch.save(checkpoint, fn)
                print('Epoch {:d} checkpoint saved: {:s}'.format(epoch, fn))

            self.train(epoch)

    def get_result_loss(self, output, target):
        """Return (translation, rotation) error for a batch via the result criterion."""
        target_var = Variable(target, requires_grad=False).cuda(non_blocking=True)
        t_loss, q_loss = self.result_criterion(output, target_var, self.pose_m,
                                               self.pose_s)

        return t_loss, q_loss

    def train(self, epoch):
        """Run one training epoch with timing and periodic logging."""
        self.model.train()
        train_data_time = Timer()
        train_batch_time = Timer()
        train_data_time.tic()
        for batch_idx, (data, target) in enumerate(self.train_dataloader):
            train_data_time.toc()

            train_batch_time.tic()
            loss, output = self.step_feedfwd(data,
                                             self.model,
                                             target=target,
                                             criterion=self.train_criterion,
                                             optim=self.optimizer,
                                             train=True)

            t_loss, q_loss = self.get_result_loss(output, target)
            train_batch_time.toc()

            if batch_idx % self.config.print_freq == 0:
                # n_itr is a global iteration counter for TensorBoard x-axis
                n_itr = (epoch - 1) * len(self.train_dataloader) + batch_idx
                epoch_count = float(n_itr) / len(self.train_dataloader)
                print(
                    'Train {:s}: Epoch {:d}\t'
                    'Batch {:d}/{:d}\t'
                    'Data time {:.4f} ({:.4f})\t'
                    'Batch time {:.4f} ({:.4f})\t'
                    'Loss {:f}'.format(self.config.experiment, epoch,
                                       batch_idx,
                                       len(self.train_dataloader) - 1,
                                       train_data_time.last_time(),
                                       train_data_time.avg_time(),
                                       train_batch_time.last_time(),
                                       train_batch_time.avg_time(), loss))
                if self.config.tf:
                    self.writer.add_scalars(self.loss_win,
                                            {"training_loss": loss}, n_itr)
                    self.writer.add_scalars(
                        self.result_win, {
                            "training_t_loss": t_loss.item(),
                            "training_q_loss": q_loss.item()
                        }, n_itr)
                    if self.n_criterion_params:
                        # Track learnable loss weights over training
                        for name, v in self.train_criterion.named_parameters():
                            v = v.data.cpu().numpy()[0]
                            self.writer.add_scalars(self.criterion_params_win,
                                                    {name: v}, n_itr)

            train_data_time.tic()

    def validate(self, epoch):
        """Run one validation pass; return the mean translation error."""
        val_batch_time = Timer()  # time for step in each batch
        val_loss = AverageMeter()
        t_loss = AverageMeter()
        q_loss = AverageMeter()
        self.model.eval()
        val_data_time = Timer()  # time for data retrieving
        val_data_time.tic()
        for batch_idx, (data, target) in enumerate(self.val_dataloader):
            val_data_time.toc()

            val_batch_time.tic()
            loss, output = self.step_feedfwd(
                data,
                self.model,
                target=target,
                criterion=self.val_criterion,
                optim=self.optimizer,  # unused when train=False
                train=False)
            # NxTx7
            val_batch_time.toc()
            val_loss.update(loss)

            t_loss_batch, q_loss_batch = self.get_result_loss(output, target)
            t_loss.update(t_loss_batch.item())
            q_loss.update(q_loss_batch.item())

            if batch_idx % self.config.print_freq == 0:
                print(
                    'Val {:s}: Epoch {:d}\t'
                    'Batch {:d}/{:d}\t'
                    'Data time {:.4f} ({:.4f})\t'
                    'Batch time {:.4f} ({:.4f})\t'
                    'Loss {:f}'.format(self.config.experiment, epoch,
                                       batch_idx,
                                       len(self.val_dataloader) - 1,
                                       val_data_time.last_time(),
                                       val_data_time.avg_time(),
                                       val_batch_time.last_time(),
                                       val_batch_time.avg_time(), loss))

            val_data_time.tic()

        print('Val {:s}: Epoch {:d}, val_loss {:f}'.format(
            self.config.experiment, epoch, val_loss.average()))
        print('Mean error in translation: {:3.2f} m\n'
              'Mean error in rotation: {:3.2f} degree'.format(
                  t_loss.average(), q_loss.average()))

        if self.config.tf:
            n_itr = (epoch - 1) * len(self.train_dataloader)
            self.writer.add_scalars(self.loss_win,
                                    {"val_loss": val_loss.average()}, n_itr)
            self.writer.add_scalars(self.result_win, {
                "val_t_loss": t_loss.average(),
                "val_q_loss": q_loss.average()
            }, n_itr)

        return t_loss.average()

    def step_feedfwd(self,
                     data,
                     model,
                     target=None,
                     criterion=None,
                     train=True,
                     **kwargs):
        """One forward (and, if ``train``, backward+step) pass.

        Returns:
            (loss_value, output) — loss is 0 when no criterion is given.
        """
        optim = kwargs["optim"]
        if train:
            assert criterion is not None
            data_var = Variable(data, requires_grad=True).cuda(non_blocking=True)
            target_var = Variable(target, requires_grad=False).cuda(non_blocking=True)
        else:
            data_var = Variable(data, requires_grad=False).cuda(non_blocking=True)
            target_var = Variable(target, requires_grad=False).cuda(non_blocking=True)

        output = model(data_var)

        if criterion is not None:
            loss = criterion(output, target_var)

            if train:
                optim.zero_grad()
                loss.backward()
                # Optional gradient clipping for training stability
                if self.config.max_grad_norm > 0.0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.config.max_grad_norm)
                optim.step()
            # .item() replaces the removed 0-dim indexing loss.data[0]
            return loss.item(), output
        else:
            return 0, output

        # Help functions
    def load_checkpoint(self):
        """Restore model (and optionally criterion/optimizer) from config.checkpoint."""
        checkpoint_file = self.config.checkpoint
        resume_optim = self.config.resume_optim
        if osp.isfile(checkpoint_file):
            loc_func = None if self.config.cuda else lambda storage, loc: storage
            # map_location: specify how to remap storage
            checkpoint = torch.load(checkpoint_file, map_location=loc_func)
            self.best_model = checkpoint
            load_state_dict(self.model, checkpoint["model_state_dict"])

            self.start_epoch = checkpoint['epoch']

            if 'criterion_state_dict' in checkpoint:
                c_state = checkpoint['criterion_state_dict']
                # Pad missing learnable-criterion keys with zeros so
                # load_state_dict doesn't fail on older checkpoints.
                append_dict = {
                    k: torch.Tensor([0, 0])
                    for k, _ in self.train_criterion.named_parameters()
                    if not k in c_state
                }
                c_state.update(append_dict)
                self.train_criterion.load_state_dict(c_state)

            print("Loaded checkpoint {:s} epoch {:d}".format(
                checkpoint_file, checkpoint['epoch']))
            print("Loss of loaded model = {}".format(checkpoint['loss']))

            if resume_optim:
                print("Load parameters in optimizer")
                self.optimizer.load_state_dict(checkpoint["optim_state_dict"])
                # Optimizer state tensors must be moved back to GPU manually
                for state in self.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()

            else:
                print("Notice: load checkpoint but didn't load optimizer.")
        else:
            print("Can't find specified checkpoint.!")
            exit(-1)

    def pack_checkpoint(self, epoch, loss=None):
        """Bundle model/optimizer/criterion state into a checkpoint dict."""
        checkpoint_dict = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optim_state_dict': self.optimizer.state_dict(),
            'criterion_state_dict': self.train_criterion.state_dict(),
            'loss': loss
        }
        return checkpoint_dict
示例#21
0
class TrainLoop_GPT2():
    """End-to-end driver for a GPT-2 dialogue model.

    Handles raw-corpus tokenization, model construction, training with
    gradient accumulation and early stopping, validation, and
    sampling-based response generation.
    """

    def __init__(self, args, logger):
        """Store the config, pick the device, then build data and model.

        :param args: parsed command-line namespace
        :param logger: logging.Logger-like object used for all output
        """
        self.args = args
        self.logger = logger

        self.args.device = 'cuda:{}'.format(
            self.args.gpu) if self.args.use_cuda else 'cpu'
        self.logger.info('using device:{}'.format(self.args.device))

        # Dict view of the namespace, kept for key-style access.
        self.opt = vars(self.args)

        self.batch_size = self.opt['batch_size']
        self.use_cuda = self.opt['use_cuda']
        self.device = self.args.device
        self.multi_gpu = self.args.use_multi_gpu

        self.build_data()
        self.build_model()

    def build_data(self):
        """Tokenize the raw corpora if requested and load tokenized splits."""
        self.tokenizer = BertTokenizer(vocab_file=self.args.vocab_path)
        self.vocab_size = len(self.tokenizer)
        self.pad_id = self.tokenizer.convert_tokens_to_ids('[PAD]')

        # Convert the raw corpora into token-id files first if asked to.
        if self.args.raw:
            for subset in ['train', 'valid', 'test']:
                self.preprocess_raw_data(subset)
        # Load tokenized data; train/valid are only needed when training.
        self.subset2data = {}
        with open(self.args.test_tokenized_path, "r", encoding="utf8") as f:
            self.subset2data['test'] = f.read()
        if not self.args.do_eval:
            with open(self.args.train_tokenized_path, "r",
                      encoding="utf8") as f:
                self.subset2data['train'] = f.read()
            with open(self.args.valid_tokenized_path, "r",
                      encoding="utf8") as f:
                self.subset2data['valid'] = f.read()
        # Each file stores one dialogue per line: split into dialogue lists.
        for subset in self.subset2data:
            self.subset2data[subset] = self.subset2data[subset].split("\n")

        self.logger.info("Train/Valid/Test set has {} convs".format(
            [len(self.subset2data[subset]) for subset in self.subset2data]))

    def build_model(self):
        """Create (or load) the GPT-2 model and move it to the device."""
        if self.args.pretrained_model:
            # A pretrained GPT-2 checkpoint was specified: load it.
            self.model = GPT2LMHeadModel.from_pretrained(
                self.args.pretrained_model)
        else:
            # No pretrained model: build a freshly initialized one from config.
            model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(
                self.args.model_config)
            self.model = GPT2LMHeadModel(config=model_config)

        # Resize the embedding matrix to match the tokenizer's vocabulary.
        self.model.resize_token_embeddings(self.vocab_size)

        if self.use_cuda:
            self.model.to(self.device)

        self.logger.info('model config:\n{}'.format(
            self.model.config.to_json_string()))

        # Context window size; longer inputs must be truncated.
        self.n_ctx = self.model.config.to_dict().get("n_ctx")

        # Create the model output directory if necessary.
        if self.args.is_model_output and not os.path.exists(
                self.args.dialogue_model_output_path):
            os.mkdir(self.args.dialogue_model_output_path)

        # Log the total number of model parameters.
        num_parameters = sum(p.numel() for p in self.model.parameters())
        self.logger.info(
            'number of model parameters: {}'.format(num_parameters))

        # Optionally wrap the model for multi-GPU data parallelism.
        if self.args.use_multi_gpu:
            if self.args.use_cuda and torch.cuda.device_count() > 1:
                self.logger.info("Let's use GPUs to train")
                # NOTE(review): self.args.device is 'cuda:N', so int(i) over
                # device.split(',') will raise ValueError; this probably
                # should parse a comma-separated GPU id list -- confirm.
                self.model = DataParallel(
                    self.model,
                    device_ids=[int(i) for i in self.args.device.split(',')])
            else:
                self.args.use_multi_gpu = False

    def train(self):
        """Train with gradient accumulation; early-stop on validation loss."""
        train_dataset = GPT2Dataset(self.subset2data['train'])
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=self.args.batch_size,
                                      shuffle=True,
                                      num_workers=self.args.num_workers,
                                      collate_fn=self.collate_fn)

        # Total optimizer steps over all epochs (drives the LR scheduler).
        self.total_steps = int(
            len(train_dataset) * self.args.epochs / self.args.batch_size /
            self.args.gradient_accumulation)
        self.logger.info('total training steps = {}'.format(self.total_steps))

        self.init_optim()

        self.logger.info('starting training')
        # Loss accumulated across logged optimizer steps.
        running_loss = 0
        # Number of optimizer steps performed so far.
        overall_step = 0
        # Number of times CUDA ran out of memory.
        oom_time = 0
        # Early stopping: stop after max_patience epochs without improvement.
        patience = 0
        max_patience = 2
        # BUGFIX: was 10000; use +inf so the first epoch always saves a model
        # even if its validation loss is very large.
        best_test_loss = float('inf')
        for epoch in range(self.args.epochs):
            epoch_start_time = datetime.now()
            train_loss = []  # per-batch losses for this epoch
            for batch_idx, (input_ids, mask_r) in enumerate(train_dataloader):
                # GPT-2 predicts the next token at every position: feeding n
                # token ids yields n hidden states, and the n-th one is used
                # to predict token n+1.
                self.model.train()
                input_ids = input_ids.to(self.device)
                # Guard against CUDA OOM during the forward/backward pass.
                try:
                    # BUGFIX: call the module, not .forward(), so hooks run.
                    outputs = self.model(input_ids=input_ids)
                    loss, accuracy = self.calculate_loss_and_accuracy(
                        outputs, input_ids, mask_r, device=self.device)
                    train_loss.append(loss.item())

                    if self.multi_gpu:
                        loss = loss.mean()
                        accuracy = accuracy.mean()
                    if self.args.gradient_accumulation > 1:
                        loss = loss / self.args.gradient_accumulation
                        accuracy = accuracy / self.args.gradient_accumulation
                    loss.backward()
                    # Clip gradients to mitigate exploding gradients.
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.args.max_grad_norm)
                    # Update parameters after accumulating enough batches.
                    if (batch_idx + 1) % self.args.gradient_accumulation == 0:
                        running_loss += loss.item()
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        self.scheduler.step()

                        overall_step += 1
                        # Periodic progress logging.
                        if (overall_step + 1) % self.args.log_step == 0:
                            self.logger.info(
                                "batch {} of epoch {}, loss {:.4f}, ppl {:.5f}"
                                .format(batch_idx + 1, epoch + 1, loss,
                                        exp(loss)))
                except RuntimeError as exception:
                    if "out of memory" in str(exception):
                        oom_time += 1
                        self.logger.info(
                            "WARNING: ran out of memory,times: {}".format(
                                oom_time))
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        self.logger.info(str(exception))
                        raise exception
            train_loss = sum(train_loss) / len(train_loss)
            epoch_finish_time = datetime.now()
            self.logger.info(
                'epoch {}, train loss is {:.4f}, ppl is {:.5f}, spend {} time'.
                format(epoch + 1, train_loss, exp(train_loss),
                       epoch_finish_time - epoch_start_time))
            # Validate; save the best model, otherwise lose patience.
            test_loss = self.val('valid')
            if test_loss <= best_test_loss:
                patience = 0
                best_test_loss = test_loss

                self.logger.info('saving model for epoch {}'.format(epoch + 1))
                model_path = join(self.args.dialogue_model_output_path,
                                  'model')
                if not os.path.exists(model_path):
                    os.mkdir(model_path)
                # Unwrap DataParallel before saving so the checkpoint holds
                # the plain module's weights.
                model_to_save = self.model.module if hasattr(
                    self.model, 'module') else self.model
                model_to_save.save_pretrained(model_path)
                self.logger.info("save model to " + str(model_path))
            else:
                patience += 1
                self.logger.info('Patience = ' + str(patience))
                if patience >= max_patience:
                    break
            test_loss = self.val('test')

    def val(self, subset):
        """Compute and log the average loss on *subset*.

        :param subset: one of 'train', 'valid', 'test'
        :return: mean per-batch loss over the subset
        """
        self.model.eval()
        test_dataset = GPT2Dataset(self.subset2data[subset])
        test_dataloader = DataLoader(test_dataset,
                                     batch_size=self.args.batch_size,
                                     shuffle=True,
                                     num_workers=self.args.num_workers,
                                     collate_fn=self.collate_fn)
        test_loss = []
        with torch.no_grad():
            for batch_idx, (input_ids, mask_r) in enumerate(test_dataloader):
                input_ids = input_ids.to(self.device)
                outputs = self.model(input_ids=input_ids)
                loss, accuracy = self.calculate_loss_and_accuracy(
                    outputs, input_ids, mask_r, device=self.device)
                test_loss.append(loss.item())
                if self.multi_gpu:
                    loss = loss.mean()
                    accuracy = accuracy.mean()
                if self.args.gradient_accumulation > 1:
                    loss = loss / self.args.gradient_accumulation
                    accuracy = accuracy / self.args.gradient_accumulation
        test_loss = sum(test_loss) / len(test_loss)
        self.logger.info("val {} loss {:.4f} , ppl {:.5f}".format(
            subset, test_loss, exp(test_loss)))

        return test_loss

    def generate(self):
        """Generate responses for the test conversations and write samples.

        For every recommender turn (except the opening one) the model is fed
        the dialogue history and samples a response token by token.
        """
        # BUGFIX: use context managers so both files are closed even when an
        # exception escapes the loop.
        with open(self.args.test_path, 'rb') as f:
            convs = pickle.load(f)

        with open(self.args.save_samples_path, 'w',
                  encoding='utf8') as samples_file:
            for conv in tqdm(convs[:]):
                conv_id = conv['conv_id']
                history = []  # token ids of the dialogue so far (model input)

                for message in conv['messages']:
                    message_id, role, content = int(
                        message['local_id']), message['role'], message['content']
                    if role == 'Recommender' and message_id != 1:
                        try:
                            if self.args.save_samples_path:
                                samples_file.write(f"[GroundTruth]: {content}\n")
                            # Each input starts with [CLS]; the history is
                            # truncated to the maximum context length.
                            input_ids = [
                                self.tokenizer.cls_token_id
                            ] + history[-self.args.max_context_len + 1:]
                            # Tensor of [input_token_num].
                            curr_input_tensor = torch.tensor(input_ids).long().to(
                                self.device)
                            generated = []
                            # Generate at most max_len tokens.
                            for _ in range(self.args.max_len):
                                outputs = self.model(
                                    input_ids=curr_input_tensor)
                                # Logits for the next token.
                                next_token_logits = outputs[0][-1, :]
                                # Repetition penalty: damp logits of tokens
                                # that were already generated.
                                for id in set(generated):
                                    next_token_logits[
                                        id] /= self.args.repetition_penalty
                                next_token_logits = next_token_logits / self.args.temperature
                                # Never generate [UNK].
                                next_token_logits[
                                    self.tokenizer.convert_tokens_to_ids(
                                        '[UNK]')] = -float('Inf')
                                # Keep only top-k / nucleus (top-p) candidates.
                                filtered_logits = top_k_top_p_filtering(
                                    next_token_logits,
                                    top_k=self.args.topk,
                                    top_p=self.args.topp)
                                # Sample the next token from the filtered
                                # distribution (higher weight, higher chance).
                                next_token = torch.multinomial(F.softmax(
                                    filtered_logits, dim=-1),
                                                               num_samples=1)
                                # [SEP] marks the end of the response.
                                if next_token == self.tokenizer.sep_token_id:
                                    break
                                generated.append(next_token.item())
                                curr_input_tensor = torch.cat(
                                    (curr_input_tensor, next_token),
                                    dim=0)[-self.n_ctx:]
                            generated_text = self.tokenizer.convert_ids_to_tokens(
                                generated)
                            if self.args.save_samples_path:
                                samples_file.write("[Generated]: {}\n\n".format(
                                    "".join(generated_text)))

                        except Exception as e:
                            print(e)
                            print(conv_id, message_id)
                            print(max(input_ids))
                            print('\n')
                    # Every message (either role) joins the history,
                    # terminated by [SEP].
                    history.extend(
                        self.tokenizer.encode(content) +
                        [self.tokenizer.sep_token_id])

    def calculate_loss_and_accuracy(self, outputs, labels, mask_r, device):
        """Compute mean loss and token accuracy over non-pad positions.

        :param outputs: model outputs; outputs[0] holds prediction scores of
            shape [batch_size, token_len, vocab_size]
        :param labels: input token ids, also used as targets
        :param mask_r: response mask aligned with labels
        :param device: device for the shifted targets
        :return: (loss, accuracy), each averaged over non-pad targets
        """
        logits = outputs[0]
        # Token i's scores predict token i+1: drop the last logits position
        # and the first labels position so they line up.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(device)
        # Mask the shifted labels so only response tokens contribute.
        mask_shift_labels = mask_r[..., 1:].contiguous().to(device)
        shift_labels = shift_labels * mask_shift_labels

        # Sum the loss over all non-pad targets; averaged below.
        loss_fct = CrossEntropyLoss(
            ignore_index=self.pad_id,
            reduction='sum')
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1))

        # Predicted token id at every position: [batch_size, token_len].
        _, preds = shift_logits.max(dim=-1)

        # 1 where the target is a real token, 0 where it is padding.
        not_ignore = shift_labels.ne(self.pad_id)
        num_targets = not_ignore.long().sum().item()

        # Count correct predictions, excluding pad positions.
        correct = (shift_labels == preds) & not_ignore
        correct = correct.float().sum()

        accuracy = correct / num_targets
        loss = loss / num_targets
        return loss, accuracy

    def preprocess_raw_data(self, subset):
        """Tokenize one raw corpus split into a token-id file.

        Each dialogue becomes "[CLS]utt1[SEP]utt2[SEP]..." truncated to the
        model context window n_ctx, written as one line of ids per dialogue.

        :param subset: one of 'train', 'valid', 'test'
        """
        if subset == 'train':
            raw_path = self.args.train_raw_path
            path = self.args.train_tokenized_path
        elif subset == 'valid':
            raw_path = self.args.valid_raw_path
            path = self.args.valid_tokenized_path
        elif subset == 'test':
            raw_path = self.args.test_raw_path
            path = self.args.test_tokenized_path
        # BUGFIX: the log referenced the undefined global 'args' and always
        # reported the train paths; report the actual subset paths.
        self.logger.info(
            "tokenizing raw data,raw data path:{}, token output path:{}".
            format(raw_path, path))

        with open(raw_path, 'rb') as f:
            data = f.read().decode("utf-8")
        # Dialogues are separated by a blank line; handle CRLF corpora too.
        if "\r\n" in data:
            train_data = data.split("\r\n\r\n")
        else:
            train_data = data.split("\n\n")
        self.logger.info("there are {} dialogue in raw dataset".format(
            len(train_data)))
        with open(path, "w", encoding="utf-8") as f:
            for dialogue_index, dialogue in enumerate(tqdm(train_data)):
                if "\r\n" in data:
                    utterances = dialogue.split("\r\n")
                else:
                    utterances = dialogue.split("\n")
                dialogue_ids = []
                for utterance in utterances:
                    dialogue_ids.extend([
                        self.tokenizer.convert_tokens_to_ids(word)
                        for word in utterance
                    ])
                    # [SEP] marks the end of each utterance.
                    dialogue_ids.append(self.tokenizer.sep_token_id)
                # Truncate to n_ctx (keeping the most recent tokens) and
                # prepend [CLS]; longer inputs would crash GPT-2.
                dialogue_ids = [self.tokenizer.cls_token_id
                                ] + dialogue_ids[-self.n_ctx + 1:]
                for dialogue_id in dialogue_ids:
                    f.write(str(dialogue_id) + ' ')
                # No trailing newline after the last dialogue.
                if dialogue_index < len(train_data) - 1:
                    f.write("\n")
        # BUGFIX: report the subset's own output path, not always train's.
        self.logger.info(
            "finish preprocessing raw data,the result is stored in {}".format(
                path))

    def collate_fn(self, batch):
        """Pad every sample in *batch* to the longest input's length.

        :param batch: list of (input_ids, mask) list pairs of equal length
        :return: (input_ids, masks) LongTensors of shape [batch, max_len]
        """
        input_ids = []
        mask_rs = []
        # Longest input in this batch; everything is padded up to it.
        max_input_len = 0
        for btc_idx, (inputs, mask_r) in enumerate(batch):
            if max_input_len < len(inputs):
                max_input_len = len(inputs)
        # Pad inputs and masks with pad_id up to max_input_len.
        for btc_idx, (inputs, mask_r) in enumerate(batch):
            assert len(inputs) == len(mask_r), f"{len(inputs)}, {len(mask_r)}"
            input_len = len(inputs)
            input_ids.append(inputs)
            input_ids[btc_idx].extend([self.pad_id] *
                                      (max_input_len - input_len))
            mask_rs.append(mask_r)
            mask_rs[btc_idx].extend([self.pad_id] *
                                    (max_input_len - input_len))
        return (torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(mask_rs, dtype=torch.long))

    def vector2sentence(self, batch_sen):
        """Convert a batch of id tensors back into lists of tokens.

        Ids <= 3 are special: 3 maps to '_UNK_', ids 0-2 are dropped.
        """
        sentences = []
        for sen in batch_sen.numpy().tolist():
            sentence = []
            for word in sen:
                if word > 3:
                    sentence.append(self.index2word[word])
                elif word == 3:
                    sentence.append('_UNK_')
            sentences.append(sentence)
        return sentences

    def optim_opts(self):
        """
        Fetch optimizer selection.

        By default, collects everything in torch.optim, as well as importing:
        - qhm / qhmadam if installed from github.com/facebookresearch/qhoptim

        Override this (and probably call super()) to add your own optimizers.
        """
        # BUGFIX: this was decorated @classmethod while reading self.logger,
        # which raised AttributeError when invoked; it is an instance method.
        optims = {
            k.lower(): v
            for k, v in optim.__dict__.items()
            if not k.startswith('__') and k[0].isupper()
        }
        try:
            import apex.optimizers.fused_adam as fused_adam
            optims['fused_adam'] = fused_adam.FusedAdam
        except ImportError:
            pass

        try:
            # https://openreview.net/pdf?id=S1fUpoR5FQ
            from qhoptim.pyt import QHM, QHAdam
            optims['qhm'] = QHM
            optims['qhadam'] = QHAdam
        except ImportError:
            # no QHM installed
            pass
        self.logger.info(optims)
        return optims

    def init_optim(self):
        """
        Initialize optimizer with model parameters.

        Uses AdamW with a linear warmup schedule over self.total_steps
        (self.total_steps must be set before calling this).
        """
        self.optimizer = transformers.AdamW(self.model.parameters(),
                                            lr=self.args.lr,
                                            correct_bias=True)
        self.scheduler = transformers.WarmupLinearSchedule(
            self.optimizer,
            warmup_steps=self.args.warmup_steps,
            t_total=self.total_steps)

    def backward(self, loss):
        """
        Perform a backward pass. It is recommended you use this instead of
        loss.backward(), for integration with distributed training and FP16
        training.
        """
        loss.backward()

    def update_params(self):
        """
        Perform step of optimization, clipping gradients and adjusting LR
        schedule if needed. Gradient accumulation is also performed if agent
        is called with --update-freq.

        It is recommended (but not forced) that you call this in train_step.
        """
        update_freq = 1
        if update_freq > 1:
            # Gradient accumulation: only step every update_freq updates.
            self._number_grad_accum = (self._number_grad_accum +
                                       1) % update_freq
            if self._number_grad_accum != 0:
                return
        # Clip gradients to the configured norm before stepping.
        if self.opt['gradient_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.opt['gradient_clip'])

        self.optimizer.step()

    def zero_grad(self):
        """
        Zero out optimizer.

        It is recommended you call this in train_step. It automatically handles
        gradient accumulation if agent is called with --update-freq.
        """
        self.optimizer.zero_grad()
示例#22
0
class AdvTrainer(BaseTrainer):
    """Adversarial QA trainer.

    Per batch it first updates the QA model (BERT + span head) with the QA
    loss, then updates a domain discriminator with the discriminator loss,
    both produced by the same DomainQA model.
    """

    def __init__(self, args):
        super(AdvTrainer, self).__init__(args)

    def make_model_env(self, gpu, ngpus_per_node):
        """Build the DomainQA model, its two optimizers, and the (possibly
        distributed) parallel wrapper for one worker process.

        :param gpu: index into args.devices for this process (distributed)
        :param ngpus_per_node: GPUs per node; used to derive the global rank
            and the per-process batch size / worker count
        """
        if self.args.distributed:
            self.args.gpu = self.args.devices[gpu]
        else:
            self.args.gpu = 0

        if self.args.use_cuda and self.args.distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            self.args.rank = self.args.rank * ngpus_per_node + gpu
            dist.init_process_group(backend=self.args.dist_backend,
                                    init_method=self.args.dist_url,
                                    world_size=self.args.world_size,
                                    rank=self.args.rank)

        self.model = DomainQA(self.args.bert_model, self.args.num_classes,
                              self.args.hidden_size, self.args.num_layers,
                              self.args.dropout, self.args.dis_lambda,
                              self.args.concat, self.args.anneal)

        if self.args.load_model is not None:
            print("Loading model from ", self.args.load_model)
            # map_location keeps tensors on CPU; moved to GPU further below.
            self.model.load_state_dict(
                torch.load(self.args.load_model,
                           map_location=lambda storage, loc: storage))

        if self.args.freeze_bert:
            # Train only the QA head / discriminator, not the BERT encoder.
            for param in self.model.bert.parameters():
                param.requires_grad = False

        # Optimization-step budget is driven by the largest feature set.
        # NOTE(review): assumes self.features_lst was prepared by the base
        # class before this is called -- confirm.
        max_len = max([len(f) for f in self.features_lst])
        num_train_optimization_steps = math.ceil(
            max_len / self.args.batch_size) * self.args.epochs * len(
                self.features_lst)

        # Separate optimizers: QA model (BERT + QA head) vs discriminator.
        qa_params = list(self.model.bert.named_parameters()) + list(
            self.model.qa_outputs.named_parameters())
        dis_params = list(self.model.discriminator.named_parameters())
        self.qa_optimizer = get_opt(qa_params, num_train_optimization_steps,
                                    self.args)
        self.dis_optimizer = get_opt(dis_params, num_train_optimization_steps,
                                     self.args)

        if self.args.use_cuda:
            if self.args.distributed:
                torch.cuda.set_device(self.args.gpu)
                self.model.cuda(self.args.gpu)
                # Split the batch and the data-loading workers across the
                # GPUs of this node.
                self.args.batch_size = int(self.args.batch_size /
                                           ngpus_per_node)
                self.args.workers = int(
                    (self.args.workers + ngpus_per_node - 1) / ngpus_per_node)
                self.model = DistributedDataParallel(
                    self.model,
                    device_ids=[self.args.gpu],
                    find_unused_parameters=True)
            else:
                self.model.cuda()
                self.model = DataParallel(self.model,
                                          device_ids=self.args.devices)

        cudnn.benchmark = True

    def train(self):
        """Run the adversarial training loop.

        Each batch performs two updates: QA model first (dtype="qa"), then
        the discriminator (dtype="dis"). Progress is printed in place;
        evaluation and checkpointing follow the configured schedule.
        """
        step = 1
        avg_qa_loss = 0
        avg_dis_loss = 0
        iter_lst = [self.get_iter(self.features_lst, self.args)]
        num_batches = sum([len(iterator[0]) for iterator in iter_lst])
        for epoch in range(self.args.start_epoch,
                           self.args.start_epoch + self.args.epochs):
            start = time.time()
            self.model.train()
            batch_step = 1
            for data_loader, sampler in iter_lst:
                if self.args.distributed:
                    # Reshuffle deterministically per epoch across workers.
                    sampler.set_epoch(epoch)

                for i, batch in enumerate(data_loader, start=1):
                    input_ids, input_mask, seg_ids, start_positions, end_positions, labels = batch

                    # remove unnecessary pad token
                    seq_len = torch.sum(torch.sign(input_ids), 1)
                    max_len = torch.max(seq_len)

                    input_ids = input_ids[:, :max_len].clone()
                    input_mask = input_mask[:, :max_len].clone()
                    seg_ids = seg_ids[:, :max_len].clone()
                    start_positions = start_positions.clone()
                    end_positions = end_positions.clone()

                    if self.args.use_cuda:
                        input_ids = input_ids.cuda(self.args.gpu,
                                                   non_blocking=True)
                        input_mask = input_mask.cuda(self.args.gpu,
                                                     non_blocking=True)
                        seg_ids = seg_ids.cuda(self.args.gpu,
                                               non_blocking=True)
                        start_positions = start_positions.cuda(
                            self.args.gpu, non_blocking=True)
                        end_positions = end_positions.cuda(self.args.gpu,
                                                           non_blocking=True)

                    # QA pass: span-prediction loss for the QA model.
                    qa_loss = self.model(input_ids,
                                         seg_ids,
                                         input_mask,
                                         start_positions,
                                         end_positions,
                                         labels,
                                         dtype="qa",
                                         global_step=step)
                    qa_loss = qa_loss.mean()
                    qa_loss.backward()

                    # update qa model
                    avg_qa_loss = self.cal_running_avg_loss(
                        qa_loss.item(), avg_qa_loss)
                    self.qa_optimizer.step()
                    self.qa_optimizer.zero_grad()

                    # update discriminator
                    dis_loss = self.model(input_ids,
                                          seg_ids,
                                          input_mask,
                                          start_positions,
                                          end_positions,
                                          labels,
                                          dtype="dis",
                                          global_step=step)
                    dis_loss = dis_loss.mean()
                    dis_loss.backward()
                    avg_dis_loss = self.cal_running_avg_loss(
                        dis_loss.item(), avg_dis_loss)
                    self.dis_optimizer.step()
                    self.dis_optimizer.zero_grad()
                    step += 1
                    # Periodic mid-epoch evaluation (skipped in epoch 0).
                    if epoch != 0 and i % 2000 == 0:
                        result_dict = self.evaluate_model(i)
                        for dev_file, f1 in result_dict.items():
                            print("GPU/CPU {} evaluated {}: {:.2f}".format(
                                self.args.gpu, dev_file, f1),
                                  end="\n")

                    batch_step += 1
                    msg = "{}/{} {} - ETA : {} - QA loss: {:.4f}, DIS loss: {:.4f}" \
                        .format(batch_step, num_batches, progress_bar(batch_step, num_batches),
                                eta(start, batch_step, num_batches),
                                avg_qa_loss, avg_dis_loss)
                    print(msg, end="\r")

            print(
                "[GPU Num: {}, Epoch: {}, Final QA loss: {:.4f}, Final DIS loss: {:.4f}]"
                .format(self.args.gpu, epoch, avg_qa_loss, avg_dis_loss))

            # save model
            if not self.args.distributed or self.args.rank == 0:
                self.save_model(epoch, avg_qa_loss)

            if self.args.do_valid:
                result_dict = self.evaluate_model(epoch)
                for dev_file, f1 in result_dict.items():
                    print("GPU/CPU {} evaluated {}: {:.2f}".format(
                        self.args.gpu, dev_file, f1),
                          end="\n")
示例#23
0
class UNetTrainer(object):
    """UNet trainer"""
    def __init__(self,
                 start_epoch=0,
                 save_dir='',
                 resume="",
                 devices_num=2,
                 num_classes=2,
                 color_dim=1):
        self.net = UNet(color_dim=color_dim, num_classes=num_classes)
        self.start_epoch = start_epoch if start_epoch != 0 else 1
        self.save_dir = os.path.join('../models', save_dir)
        self.loss = CrossEntropyLoss()
        self.num_classes = num_classes
        if resume:
            checkpoint = torch.load(resume)
            if self.start_epoch == 0:
                self.start_epoch = checkpoint['epoch'] + 1
            if not self.save_dir:
                self.save_dir = checkpoint['save_dir']
            self.net.load_state_dict(checkpoint['state_dir'])
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        # self.net.cuda()
        # self.loss.cuda()
        if devices_num == 2:
            self.net = DataParallel(self.net, device_ids=[0, 1])
        # self.loss = DataParallel(self.loss, device_ids=[0, 1])

    def train(self,
              train_loader,
              val_loader,
              lr=0.001,
              weight_decay=1e-4,
              epochs=200,
              save_freq=10):
        self.logfile = os.path.join(self.save_dir, 'log')
        sys.stdout = Logger(self.logfile)
        self.epochs = epochs
        self.lr = lr
        optimizer = torch.optim.Adam(
            self.net.parameters(),
            # lr,
            # momentum=0.9,
            weight_decay=weight_decay)
        for epoch in range(self.start_epoch, epochs + 1):
            self.train_(train_loader, epoch, optimizer, save_freq)
            self.validate_(val_loader, epoch)

    def train_(self, data_loader, epoch, optimizer, save_freq=10):
        start_time = time.time()
        self.net.train()
        # lr = self.get_lr(epoch)
        # for param_group in optimizer.param_groups:
        # param_group['lr'] = lr
        metrics = []
        for i, (data, target) in enumerate(tqdm(data_loader)):
            data_t, target_t = data, target
            # data = Variable(data.cuda(async = True))
            # target = Variable(target.cuda(async = True))
            data = Variable(data)
            target = Variable(target)
            output = self.net(data)  # UNet输出结果
            output = output.transpose(1, 3).transpose(1, 2).contiguous().view(
                -1, self.num_classes)
            target = target.view(-1)
            loss_output = self.loss(output, target)
            optimizer.zero_grad()
            loss_output.backward()  # 反向传播Loss
            optimizer.step()
            loss_output = loss_output.data.item()  # Loss数值
            acc = accuracy(output, target)
            metrics.append([loss_output, acc])
            if i == 0:
                batch_size = data.size(0)
                _, output = output.data.max(
                    dim=1
                )  #  _为最大值,output 为output.data.max(按行)的最大值的索引,如果为0 ,则第一层的卷积出来的数大,反之为第二层
                output = output.view(batch_size, 1, 1, 320, 480).cpu()  # 预测结果图
                data_t = data_t[0, 0].unsqueeze(0).unsqueeze(0)  # 原img图
                target_t = target_t[0].unsqueeze(0)  # gt图
                t = torch.cat([output[0].float(), data_t, target_t.float()], 0)
                # 第一个参数为list,拼接3张图像
                # show_list = []
                # for j in range(10):
                #    show_list.append(data_t[j, 0].unsqueeze(0).unsqueeze(0))
                #    show_list.append(target_t[j].unsqueeze(0))
                #    show_list.append(output[j].float())
                #
                # t = torch.cat(show_list, 0)
                torchvision.utils.save_image(t,
                                             "../Try/temp/%02d_train.jpg" %
                                             epoch,
                                             nrow=3)
            # if i == 20:
            # break
        if epoch % save_freq == 0:
            if 'module' in dir(self.net):
                state_dict = self.net.module.state_dict()
            else:
                state_dict = self.net.state_dict()
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cpu()
            torch.save(
                {
                    'epoch': epoch,
                    'save_dir': self.save_dir,
                    'state_dir': state_dict
                }, os.path.join(self.save_dir, '%03d.ckpt' % epoch))
        end_time = time.time()
        metrics = np.asarray(metrics, np.float32)
        self.print_metrics(metrics, 'Train', end_time - start_time, epoch)

    def validate_(self, data_loader, epoch):
        start_time = time.time()
        self.net.eval()
        metrics = []
        for i, (data, target) in enumerate(tqdm(data_loader)):
            data_t, target_t = data, target
            # data = Variable(data.cuda(async = True), volatile = True)
            # target = Variable(target.cuda(async = True), volatile = True)
            data = Variable(data, requires_grad=False)
            target = Variable(target, requires_grad=False)
            output = self.net(data)
            output = output.transpose(1, 3).transpose(1, 2).contiguous().view(
                -1, self.num_classes)
            target = target.view(-1)
            loss_output = self.loss(output, target)
            loss_output = loss_output.data.item()
            acc = accuracy(output, target)
            metrics.append([loss_output, acc])
            if i == 0:
                batch_size = data.size(0)
                _, output = output.data.max(dim=1)
                output = output.view(batch_size, 1, 1, 320, 480).cpu()
                data_t = data_t[0, 0].unsqueeze(0).unsqueeze(0)
                target_t = target_t[0].unsqueeze(0)
                t = torch.cat([output[0].float(), data_t, target_t.float()], 0)
                # show_list = []
                # for j in range(10):
                #   show_list.append(data_t[j, 0].unsqueeze(0).unsqueeze(0))
                #   show_list.append(target_t[j].unsqueeze(0))
                #   show_list.append(output[j].float())
                #
                # t = torch.cat(show_list, 0)
                torchvision.utils.save_image(t,
                                             "../Try/temp/%02d_train.jpg" %
                                             epoch,
                                             nrow=3)
            # if i == 10:
            #   break
        end_time = time.time()
        metrics = np.asarray(metrics, np.float32)
        self.print_metrics(metrics, 'Validation', end_time - start_time)

    def print_metrics(self, metrics, phase, time, epoch=-1):
        """metrics: [loss, acc]
        """
        if epoch != -1:
            print("Epoch: {}".format(epoch), )
        print(phase, )
        print('loss %2.4f, accuracy %2.4f, time %2.2f' %
              (np.mean(metrics[:, 0]), np.mean(metrics[:, 1]), time))
        if phase != 'Train':
            print()

    def get_lr(self, epoch):
        if epoch <= self.epochs * 0.5:
            lr = self.lr
        elif epoch <= self.epochs * 0.8:
            lr = 0.1 * self.lr
        else:
            lr = 0.01 * self.lr
        return lr

    def save_py_files(self, path):
        """copy .py files in exps dir, cfgs dir and current dir into
            save_dir, and keep the files structure
        """
        # exps dir
        pyfiles = [f for f in os.listdir(path) if f.endswith('.py')]
        path = "/".join(path.split('/')[-2:])
        exp_save_path = os.path.join(self.save_dir, path)
        mkdir(exp_save_path)
        for f in pyfiles:
            shutil.copy(os.path.join(path, f), os.path.join(exp_save_path, f))
        # current dir
        pyfiles = [f for f in os.listdir('./') if f.endswith('.py')]
        for f in pyfiles:
            shutil.copy(f, os.path.join(self.save_dir, f))
        # cfgs dir
        shutil.copytree('./cfgs', os.path.join(self.save_dir, 'cfgs'))
示例#24
0
def train(args):
    """Train a segmentation model with periodic validation and best-model
    checkpointing.

    :param args: parsed CLI namespace. Reads encoder_type, work_dir, ckp,
        batch_size, optim_name, lr, lrs, factor, patience, min_lr, t_max,
        no_first_val, val, num_epochs, iter_val.
    :raises ValueError: if ``args.optim_name`` is not one of
        'RAdam', 'Adam', 'SGD'.
    """
    model, model_file = create_model(args.encoder_type,
                                     work_dir=args.work_dir,
                                     ckp=args.ckp)
    model = model.cuda()

    loaders = get_train_val_loaders(batch_size=args.batch_size)

    if args.optim_name == 'RAdam':
        optimizer = RAdam(model.parameters(), lr=args.lr)
    elif args.optim_name == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim_name == 'SGD':
        optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=args.lr)
    else:
        # BUG FIX: an unrecognized optimizer name previously left `optimizer`
        # unbound and crashed later with a confusing NameError.
        raise ValueError('unknown optim_name: {!r}'.format(args.optim_name))

    if torch.cuda.device_count() > 1:
        model = DataParallel(model)

    if args.lrs == 'plateau':
        lr_scheduler = ReduceLROnPlateau(optimizer,
                                         mode='max',
                                         factor=args.factor,
                                         patience=args.patience,
                                         min_lr=args.min_lr)
    else:
        lr_scheduler = CosineAnnealingLR(optimizer,
                                         args.t_max,
                                         eta_min=args.min_lr)

    best_metrics = 0.
    best_key = 'dice'  # metric used to decide whether a checkpoint is "best"

    print(
        'epoch |    lr    |      %        |  loss  |  avg   |   loss |  dice  |  best  | time |  save |'
    )

    # Optional baseline validation before any training step.
    if not args.no_first_val:
        val_metrics = validate(args, model, loaders['valid'])
        print(
            'val   |          |               |        |        | {:.4f} | {:.4f} | {:.4f} |        |        |'
            .format(val_metrics['loss'], val_metrics['dice'],
                    val_metrics['dice']))

        best_metrics = val_metrics[best_key]

    if args.val:
        return  # validation-only mode

    model.train()

    train_iter = 0

    for epoch in range(args.num_epochs):
        train_loss = 0

        current_lr = get_lrs(optimizer)
        bg = time.time()
        for batch_idx, data in enumerate(loaders['train']):
            train_iter += 1
            img, targets = data[0].cuda(), data[1].cuda()
            batch_size = img.size(0)

            outputs = model(img)
            loss = _reduce_loss(criterion(outputs, targets))
            loss.backward()

            # Gradient accumulation: apply the optimizer every 4 batches.
            # NOTE(review): the trailing partial group of an epoch is not
            # stepped until the next epoch's first step — confirm intended.
            if batch_idx % 4 == 0:
                optimizer.step()
                optimizer.zero_grad()

            train_loss += loss.item()
            print('\r {:4d} | {:.6f} | {:06d}/{} | {:.4f} | {:.4f} |'.format(
                epoch, float(current_lr[0]),
                args.batch_size * (batch_idx + 1), loaders['train'].num,
                loss.item(), train_loss / (batch_idx + 1)),
                  end='')

            # Periodic validation every `iter_val` iterations; checkpoint the
            # latest model and keep the best-by-dice model separately.
            if train_iter > 0 and train_iter % args.iter_val == 0:
                save_model(model, model_file + '_latest')
                val_metrics = validate(args, model, loaders['valid'])

                _save_ckp = ''
                if val_metrics[best_key] > best_metrics:
                    best_metrics = val_metrics[best_key]
                    save_model(model, model_file)
                    _save_ckp = '*'
                print(' {:.4f} | {:.4f} | {:.4f} | {:.2f} |  {:4s} |'.format(
                    val_metrics['loss'], val_metrics['dice'], best_metrics,
                    (time.time() - bg) / 60, _save_ckp))

                model.train()  # validate() switched the model to eval mode

                if args.lrs == 'plateau':
                    lr_scheduler.step(best_metrics)
                else:
                    lr_scheduler.step()
                current_lr = get_lrs(optimizer)
示例#25
0
class Trainer:
    """Step-based training driver for LMCNet.

    Builds train/val DataLoaders, the network, losses and metrics, then runs
    a fixed number of optimization steps with periodic validation,
    best-model tracking, and checkpointing under ``data/model/<name>``.
    """
    def _init_dataset(self):
        """Create train/val datasets per cfg['dataset_type'] and wrap them
        in DataLoaders (stored as self.train_set / self.val_set)."""
        if self.cfg['dataset_type'] == 'pose_feats':
            train_set = RelPoseFeatsDataset(self.cfg['dataset_root'],
                                            self.cfg['dataset_extract_name'],
                                            self.cfg['dataset_match_name'],
                                            self.cfg['train_pair_info_fn'],
                                            self.cfg['epipolar_inlier_thresh'],
                                            self.cfg['use_eig'],
                                            self.cfg['dataset_eig_name'],
                                            self.cfg['use_feats'],
                                            is_train=True)
            val_set = RelPoseFeatsDataset(self.cfg['dataset_root'],
                                          self.cfg['dataset_extract_name'],
                                          self.cfg['dataset_match_name'],
                                          self.cfg['val_pair_info_fn'],
                                          self.cfg['epipolar_inlier_thresh'],
                                          self.cfg['use_eig'],
                                          self.cfg['dataset_eig_name'],
                                          self.cfg['use_feats'],
                                          is_train=False)
        elif self.cfg['dataset_type'] == 'detrac':
            # Default cache dir unless the config overrides it.
            root_dir = 'data/detrac_train_cache' if 'root_dir' not in self.cfg else self.cfg[
                'root_dir']
            # NOTE(review): the trailing positional booleans differ between
            # train (True, True) and val (False) — presumably is_train-style
            # flags; confirm against DETRACTrainDataset's signature.
            train_set = DETRACTrainDataset(self.cfg['train_pair_info_fn'],
                                           root_dir,
                                           self.cfg['dataset_extract_name'],
                                           self.cfg['dataset_match_name'],
                                           self.cfg['use_eig'],
                                           self.cfg['eig_name'], True, True)
            val_set = DETRACTrainDataset(self.cfg['val_pair_info_fn'],
                                         root_dir,
                                         self.cfg['dataset_extract_name'],
                                         self.cfg['dataset_match_name'],
                                         self.cfg['use_eig'],
                                         self.cfg['eig_name'], False)
        else:
            raise NotImplementedError

        # Positional args after the dataset are batch_size and shuffle:
        # training shuffles, validation does not.
        self.train_set = DataLoader(train_set,
                                    self.cfg['batch_size'],
                                    True,
                                    num_workers=16,
                                    pin_memory=False,
                                    collate_fn=collate_fn)
        self.val_set = DataLoader(val_set,
                                  self.cfg['batch_size'],
                                  False,
                                  num_workers=4,
                                  collate_fn=collate_fn)
        print(f'train set len {len(self.train_set)}')
        print(f'val set len {len(self.val_set)}')

    def _init_network(self):
        """Build network, optimizer, losses/metrics, optional multi-GPU
        wrapper, optional finetune restore, and the validation evaluator."""
        self.network = LMCNet(self.cfg).cuda()
        self.optimizer = Adam(self.network.parameters(), lr=1e-3)

        # Losses listed in the config, instantiated by name.
        self.val_losses = []
        for loss_name in self.cfg['loss']:
            self.val_losses.append(name2loss[loss_name](self.cfg))
        self.val_metrics = []

        # A validation metric may be either a metric or a loss by name.
        for metric_name in self.cfg['val_metric']:
            if metric_name in name2metric:
                self.val_metrics.append(name2metric[metric_name](self.cfg))
            else:
                self.val_metrics.append(name2loss[metric_name](self.cfg))

        if self.cfg['multi_gpus']:
            # make multi gpu network
            self.train_network = DataParallel(
                MultiGPUWrapper(self.network, self.val_losses))
            self.train_losses = [DummyLoss(self.val_losses)]
        else:
            self.train_network = self.network
            self.train_losses = self.val_losses

        if 'finetune' in self.cfg and self.cfg['finetune']:
            checkpoint = torch.load(self.cfg['finetune_path'])
            self.network.load_state_dict(checkpoint['network_state_dict'])
            print(f'==> resuming from step {self.cfg["finetune_path"]}')
        self.val_evaluator = ValidationEvaluator(self.cfg)

    def __init__(self, cfg):
        """Store cfg and prepare the model directory / checkpoint paths.
        Heavy setup is deferred to run()."""
        self.cfg = cfg
        self.model_dir = os.path.join('data/model', cfg['name'])
        if not os.path.exists(self.model_dir): os.mkdir(self.model_dir)
        self.pth_fn = os.path.join(self.model_dir, 'model.pth')
        self.best_pth_fn = os.path.join(self.model_dir, 'model_best.pth')

    def run(self):
        """Main loop: resume if possible, then train for cfg['total_step']
        steps with periodic logging, validation and checkpointing."""
        self._init_dataset()
        self._init_network()
        self._init_logger()

        best_para, start_step = self._load_model()
        train_iter = iter(self.train_set)

        pbar = tqdm(total=self.cfg['total_step'], bar_format='{r_bar}')
        pbar.update(start_step)
        for step in range(start_step, self.cfg['total_step']):
            # Re-create the iterator when the loader is exhausted so the
            # loop is bounded by steps, not epochs.
            try:
                train_data = next(train_iter)
            except StopIteration:
                train_iter = iter(self.train_set)
                train_data = next(train_iter)
            if not self.cfg['multi_gpus']:
                # DataParallel scatters to GPUs itself; otherwise move here.
                train_data = to_cuda(train_data)
            train_data['step'] = step

            self.train_network.train()
            self.network.train()
            # Manual schedule: set the lr for this step directly.
            reset_learning_rate(self.optimizer, self._get_lr(step))
            self.optimizer.zero_grad()
            self.train_network.zero_grad()

            log_info = {}
            outputs = self.train_network(train_data)
            for loss in self.train_losses:
                loss_results = loss(outputs, train_data, step)
                for k, v in loss_results.items():
                    log_info[k] = v

            # Total loss = mean of every logged entry whose key starts
            # with 'loss'.
            loss = 0
            for k, v in log_info.items():
                if k.startswith('loss'):
                    loss = loss + torch.mean(v)

            loss.backward()
            self.optimizer.step()
            if ((step + 1) % self.cfg['train_log_step']) == 0:
                self._log_data(log_info, step + 1, 'train')

            if (step + 1) % self.cfg['val_interval'] == 0:
                val_results, val_para = self.val_evaluator(
                    self.network, self.val_losses + self.val_metrics,
                    self.val_set)
                # Track and save the best model by the key validation metric.
                if val_para > best_para:
                    print(
                        f'New best model {self.cfg["key_metric_name"]}: {val_para:.5f} previous {best_para:.5f}'
                    )
                    best_para = val_para
                    # if self.cfg['save_inter_model'] and (step+1)%self.cfg['save_inter_interval']==0:
                    #     self._save_model(step + 1, best_para, os.path.join(self.model_dir,f'{step+1}.pth'))
                    self._save_model(step + 1, best_para, self.best_pth_fn)
                self._log_data(val_results, step + 1, 'val')

            if (step + 1) % self.cfg['save_interval'] == 0:
                # if self.cfg['save_inter_model'] and (step+1)%10000==0:
                #     self._save_model(step+1,best_para,f'{self.model_dir}/{step}.pth')
                self._save_model(step + 1, best_para)

            pbar.set_postfix(loss=float(loss.detach().cpu().numpy()))
            pbar.update(1)

        pbar.close()

    def _load_model(self):
        """Restore network/optimizer from model.pth if present.

        :return: (best_para, start_step); (0, 0) when no checkpoint exists.
        """
        best_para, start_step = 0, 0
        if os.path.exists(self.pth_fn):
            checkpoint = torch.load(self.pth_fn)
            best_para = checkpoint['best_para']
            start_step = checkpoint['step']
            self.network.load_state_dict(checkpoint['network_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f'==> resuming from step {start_step} best para {best_para}')

        return best_para, start_step

    def _save_model(self, step, best_para, save_fn=None):
        """Checkpoint network + optimizer state; defaults to model.pth."""
        save_fn = self.pth_fn if save_fn is None else save_fn
        torch.save(
            {
                'step': step,
                'best_para': best_para,
                'network_state_dict': self.network.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
            }, save_fn)

    def _init_logger(self):
        """Create the file logger rooted at the model directory."""
        self.logger = Logger(self.model_dir)

    def _log_data(self, results, step, prefix='train', verbose=False):
        """Reduce every result (scalar, ndarray or tensor) to a scalar
        mean and forward it to the logger under `prefix`."""
        log_results = {}
        for k, v in results.items():
            if isinstance(v, float) or np.isscalar(v):
                log_results[k] = v
            elif type(v) == np.ndarray:
                log_results[k] = np.mean(v)
            else:
                log_results[k] = np.mean(v.detach().cpu().numpy())
        self.logger.log(log_results, prefix, step, verbose)

    def _get_lr(self, step):
        """Learning rate for `step`.

        'default': constant lr_start until lr_mid_epoch, then exponential
        decay (floored at lr_min). 'warm_up': constant lr_warm_up during
        warm-up, then the same decay from lr_start.

        NOTE(review): any other lr_type falls off the end and returns None,
        which would break reset_learning_rate — confirm only these two
        values occur in configs.
        """
        if 'lr_type' not in self.cfg or self.cfg['lr_type'] == 'default':
            if step <= self.cfg['lr_mid_epoch']:
                return self.cfg['lr_start']
            else:
                decay_rate = self.cfg['lr_decay_rate']
                decay_step = self.cfg['lr_decay_step']
                decay_num = (step - self.cfg['lr_mid_epoch']) // decay_step
                return max(self.cfg['lr_start'] * decay_rate**decay_num,
                           self.cfg['lr_min'])
        elif self.cfg['lr_type'] == 'warm_up':
            if step <= self.cfg['lr_warm_up_step']:
                return self.cfg['lr_warm_up']
            else:
                decay_rate = self.cfg['lr_decay_rate']
                decay_step = self.cfg['lr_decay_step']
                decay_num = (step - self.cfg['lr_warm_up_step']) // decay_step
                return max(self.cfg['lr_start'] * decay_rate**decay_num,
                           self.cfg['lr_min'])
示例#26
0
                                            milestones=[36, 52, 58],
                                            gamma=0.1)

net = net.cuda()
ArcMargin = ArcMargin.cuda()
if multi_gpus:
    net = DataParallel(net)
    ArcMargin = DataParallel(ArcMargin)
criterion = torch.nn.CrossEntropyLoss()

best_acc = 0.0
best_epoch = 0
for epoch in range(start_epoch, TOTAL_EPOCH + 1):
    # train model
    _print('Train Epoch: {}/{} ...'.format(epoch, TOTAL_EPOCH))
    net.train()

    train_total_loss = 0.0
    total = 0
    since = time.time()
    for data in trainloader:
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        optimizer_ft.zero_grad()

        raw_logits = net(img)

        output = ArcMargin(raw_logits, label)
        total_loss = criterion(output, label)
        total_loss.backward()
        optimizer_ft.step()
示例#27
0
class img2poseModel:
    """Wrapper around a FasterDoFRCNN pose-estimation network.

    Builds a ResNet-FPN backbone, wraps the model for the available
    execution mode (DistributedDataParallel, CPU, or DataParallel), and
    optionally restores weights from ``model_path``.
    """
    def __init__(
        self,
        depth,
        min_size,
        max_size,
        model_path=None,
        device=None,
        pose_mean=None,
        pose_stddev=None,
        distributed=False,
        gpu=0,
        threed_68_points=None,
        threed_5_points=None,
        rpn_pre_nms_top_n_test=6000,
        rpn_post_nms_top_n_test=1000,
        bbox_x_factor=1.1,
        bbox_y_factor=1.1,
        expand_forehead=0.3,
    ):
        """
        :param depth: ResNet depth, used to pick the backbone (resnet<depth>).
        :param min_size: minimum image size fed to the detector.
        :param max_size: maximum image size fed to the detector.
        :param model_path: optional checkpoint to load; model is set to eval.
        :param device: torch device; auto-selects cuda:0/cpu when None.
        :param pose_mean: pose normalization mean (converted to tensor).
        :param pose_stddev: pose normalization stddev (converted to tensor).
        :param distributed: wrap in DistributedDataParallel when True.
        :param gpu: local GPU id used for the DDP device_ids.
        :param threed_68_points: 3D reference points (68-landmark set).
        :param threed_5_points: 3D reference points (5-landmark set).
        :param rpn_pre_nms_top_n_test: RPN proposals kept before NMS at test.
        :param rpn_post_nms_top_n_test: RPN proposals kept after NMS at test.
        :param bbox_x_factor: horizontal box expansion factor.
        :param bbox_y_factor: vertical box expansion factor.
        :param expand_forehead: forehead expansion ratio for face boxes.
        """
        self.depth = depth
        self.min_size = min_size
        self.max_size = max_size
        self.model_path = model_path
        self.distributed = distributed
        self.gpu = gpu

        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # create network backbone
        backbone = resnet_fpn_backbone(f"resnet{self.depth}", pretrained=False)

        # NOTE(review): pose_stddev is only converted when pose_mean is
        # given — presumably the two are always supplied together; confirm.
        if pose_mean is not None:
            pose_mean = torch.tensor(pose_mean)
            pose_stddev = torch.tensor(pose_stddev)

        if threed_68_points is not None:
            threed_68_points = torch.tensor(threed_68_points)

        if threed_5_points is not None:
            threed_5_points = torch.tensor(threed_5_points)

        # create the feature pyramid network
        self.fpn_model = FasterDoFRCNN(
            backbone,
            2,
            min_size=self.min_size,
            max_size=self.max_size,
            pose_mean=pose_mean,
            pose_stddev=pose_stddev,
            threed_68_points=threed_68_points,
            threed_5_points=threed_5_points,
            rpn_pre_nms_top_n_test=rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_test=rpn_post_nms_top_n_test,
            bbox_x_factor=bbox_x_factor,
            bbox_y_factor=bbox_y_factor,
            expand_forehead=expand_forehead,
        )

        # if using cpu, remove the parallel modules from the saved model
        # fpn_model_without_ddp keeps a handle on the unwrapped model so
        # checkpoints can be loaded without the parallel-wrapper prefix.
        self.fpn_model_without_ddp = self.fpn_model

        if self.distributed:
            self.fpn_model = self.fpn_model.to(self.device)
            self.fpn_model = DistributedDataParallel(
                self.fpn_model, device_ids=[self.gpu]
            )
            self.fpn_model_without_ddp = self.fpn_model.module

            print("Model will use distributed mode!")

        elif str(self.device) == "cpu":
            self.fpn_model = WrappedModel(self.fpn_model)
            self.fpn_model_without_ddp = self.fpn_model

            print("Model will run on CPU!")

        else:
            # NOTE(review): unlike the DDP branch, here the "without_ddp"
            # handle points at the DataParallel wrapper, not .module —
            # confirm load_model handles the 'module.' key prefix.
            self.fpn_model = DataParallel(self.fpn_model)
            self.fpn_model = self.fpn_model.to(self.device)
            self.fpn_model_without_ddp = self.fpn_model

            print(f"Model will use {torch.cuda.device_count()} GPUs!")

        if self.model_path is not None:
            self.load_saved_model(self.model_path)
            self.evaluate()

    def load_saved_model(self, model_path):
        """Load weights into the unwrapped model (CPU-remapped on CPU)."""
        load_model(
            self.fpn_model_without_ddp, model_path, cpu_mode=str(self.device) == "cpu"
        )

    def evaluate(self):
        """Switch the model to eval mode."""
        self.fpn_model.eval()

    def train(self):
        """Switch the model to train mode."""
        self.fpn_model.train()

    def run_model(self, imgs, targets=None):
        """Forward pass; returns losses in train mode, detections in eval."""
        outputs = self.fpn_model(imgs, targets)

        return outputs

    def forward(self, imgs, targets):
        """Training-style forward: returns the loss dict."""
        losses = self.run_model(imgs, targets)

        return losses

    def predict(self, imgs):
        """Inference on a batch of images; model must be in eval mode."""
        assert self.fpn_model.training is False

        with torch.no_grad():
            predictions = self.run_model(imgs)

        return predictions
示例#28
0
    dataloader = DataLoader(dataset,
                            batch_size=BATCH_SIZE * 2,
                            num_workers=NUM_WORKERS)
    validate_num = len(dataset)
    validate_loss = 0
    with torch.no_grad():
        model.eval()

        for _, imgs, masks, _ in tqdm(dataloader):
            imgs = imgs.float().cuda()
            masks = masks.float().cuda()
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            validate_loss += (loss.item() * imgs.shape[0])

        model.train()
    validate_loss /= validate_num
    lr = optimizer.param_groups[0]['lr']
    scheduler.step(validate_loss)
    print('Epoch{}:\t Train-{:.3f}\tValidate-{:.4f}\tlr-{}e-5'.format(
        epoch, train_loss, validate_loss, lr * 100000.))
    # agent.append(train_loss_record, epoch, train_loss)
    # agent.append(validate_loss_record, epoch, validate_loss)

    if validate_loss < min_loss:
        min_loss = validate_loss
        early_stop_counter = 0
        # if len(device_ids) > 1:
        #     torch.save(model.module.cpu().state_dict(), os.path.join(save_dir, 'model_{}.pth'.format(FOLD)))
        # else:
        #     torch.save(model.cpu().state_dict(), os.path.join(save_dir, 'model_{}.pth'.format(FOLD)))
示例#29
0
class ConditionalProGAN:
    """ Wrapper around the Generator and the Conditional Discriminator """

    def __init__(self, num_classes, depth=7, latent_size=512,
                 learning_rate=0.001, beta_1=0, beta_2=0.99,
                 eps=1e-8, drift=0.001, n_critic=1, use_eql=True,
                 loss="wgan-gp", use_ema=True, ema_decay=0.999,
                 device=th.device("cpu")):
        """
        constructor for the class
        :param num_classes: number of classes required for the conditional gan
        :param depth: depth of the GAN (will be used for each generator and discriminator)
        :param latent_size: latent size of the manifold used by the GAN
        :param learning_rate: learning rate for Adam
        :param beta_1: beta_1 for Adam
        :param beta_2: beta_2 for Adam
        :param eps: epsilon for Adam
        :param n_critic: number of times to update discriminator
                         (Used only if loss is wgan or wgan-gp)
        :param drift: drift penalty for the
                      (Used only if loss is wgan or wgan-gp)
        :param use_eql: whether to use equalized learning rate
        :param loss: the loss function to be used
                     Can either be a string =>
                          ["wgan-gp", "wgan", "lsgan", "lsgan-with-sigmoid",
                          "hinge", "standard-gan" or "relativistic-hinge"]
                     Or an instance of ConditionalGANLoss
        :param use_ema: boolean for whether to use exponential moving averages
        :param ema_decay: value of mu for ema
        :param device: device to run the GAN on (GPU / CPU)
        """

        from torch.optim import Adam
        from torch.nn import DataParallel

        # Create the Generator and the Discriminator
        self.gen = Generator(depth, latent_size, use_eql=use_eql).to(device)
        self.dis = ConditionalDiscriminator(
            num_classes, height=depth,
            feature_size=latent_size,
            use_eql=use_eql).to(device)

        # if code is to be run on GPU, we can use DataParallel:
        if device == th.device("cuda"):
            self.gen = DataParallel(self.gen)
            self.dis = DataParallel(self.dis)

        # state of the object
        self.latent_size = latent_size
        self.depth = depth
        self.use_ema = use_ema
        self.num_classes = num_classes  # required for matching aware
        self.ema_decay = ema_decay
        self.n_critic = n_critic
        self.use_eql = use_eql
        self.device = device
        self.drift = drift

        # define the optimizers for the discriminator and generator
        self.gen_optim = Adam(self.gen.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        self.dis_optim = Adam(self.dis.parameters(), lr=learning_rate,
                              betas=(beta_1, beta_2), eps=eps)

        # define the loss function used for training the GAN
        self.loss = self.__setup_loss(loss)

        # setup the ema for the generator
        if self.use_ema:
            from pro_gan_pytorch.CustomLayers import update_average

            # create a shadow copy of the generator
            self.gen_shadow = copy.deepcopy(self.gen)

            # updater function:
            self.ema_updater = update_average

            # initialize the gen_shadow weights equal to the
            # weights of gen
            self.ema_updater(self.gen_shadow, self.gen, beta=0)

    def __setup_loss(self, loss):
        """
        resolve the loss argument into a ConditionalGANLoss instance
        :param loss: either a known loss-name string or a ConditionalGANLoss
        :return: the resolved loss object
        :raises ValueError: on unknown string or wrong type
        """
        import pro_gan_pytorch.Losses as losses

        if isinstance(loss, str):
            loss = loss.lower()  # lowercase the string
            if loss == "wgan":
                loss = losses.CondWGAN_GP(self.dis, self.drift, use_gp=False)
                # note if you use just wgan, you will have to use weight clipping
                # in order to prevent gradient exploding

            elif loss == "wgan-gp":
                loss = losses.CondWGAN_GP(self.dis, self.drift, use_gp=True)

            elif loss == "lsgan":
                loss = losses.CondLSGAN(self.dis)

            elif loss == "lsgan-with-sigmoid":
                loss = losses.CondLSGAN_SIGMOID(self.dis)

            elif loss == "hinge":
                loss = losses.CondHingeGAN(self.dis)

            elif loss == "standard-gan":
                loss = losses.CondStandardGAN(self.dis)

            elif loss == "relativistic-hinge":
                loss = losses.CondRelativisticAverageHingeGAN(self.dis)

            else:
                raise ValueError("Unknown loss function requested")

        elif not isinstance(loss, losses.ConditionalGANLoss):
            raise ValueError("loss is neither an instance of GANLoss nor a string")

        return loss

    def __progressive_downsampling(self, real_batch, depth, alpha):
        """
        private helper for downsampling the original images in order to facilitate the
        progressive growing of the layers.
        :param real_batch: batch of real samples
        :param depth: depth at which training is going on
        :param alpha: current value of the fader alpha
        :return: real_samples => modified real batch of samples
        """

        from torch.nn import AvgPool2d
        from torch.nn.functional import interpolate

        # downsample the real_batch for the given depth
        down_sample_factor = int(np.power(2, self.depth - depth - 1))
        prior_downsample_factor = max(int(np.power(2, self.depth - depth)), 0)

        ds_real_samples = AvgPool2d(down_sample_factor)(real_batch)

        if depth > 0:
            prior_ds_real_samples = interpolate(AvgPool2d(prior_downsample_factor)(real_batch),
                                                scale_factor=2)
        else:
            prior_ds_real_samples = ds_real_samples

        # real samples are a combination of ds_real_samples and prior_ds_real_samples
        real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)

        # return the so computed real_samples
        return real_samples

    def optimize_discriminator(self, noise, real_batch, labels, depth, alpha):
        """
        performs one step of weight update on discriminator using the batch of data
        :param noise: input noise of sample generation
        :param real_batch: real samples batch
        :param labels: (conditional classes) should be a list of integers
        :param depth: current depth of optimization
        :param alpha: current alpha for fade-in
        :return: current loss value
        """

        real_samples = self.__progressive_downsampling(real_batch, depth, alpha)

        loss_val = 0
        for _ in range(self.n_critic):
            # generate a batch of samples
            fake_samples = self.gen(noise, depth, alpha).detach()

            loss = self.loss.dis_loss(real_samples, fake_samples,
                                      labels, depth, alpha)

            # optimize discriminator
            self.dis_optim.zero_grad()
            loss.backward()
            self.dis_optim.step()

            loss_val += loss.item()

        return loss_val / self.n_critic

    def optimize_generator(self, noise, real_batch, labels, depth, alpha):
        """
        performs one step of weight update on generator for the given batch_size
        :param noise: input random noise required for generating samples
        :param real_batch: real batch of samples (real samples)
        :param labels: labels for conditional discrimination
        :param depth: depth of the network at which optimization is done
        :param alpha: value of alpha for fade-in effect
        :return: current loss (Wasserstein estimate)
        """

        # create batch of real samples
        real_samples = self.__progressive_downsampling(real_batch, depth, alpha)

        # generate fake samples:
        fake_samples = self.gen(noise, depth, alpha)

        # TODO_complete:
        # Change this implementation for making it compatible for relativisticGAN
        loss = self.loss.gen_loss(real_samples, fake_samples, labels, depth, alpha)

        # optimize the generator
        self.gen_optim.zero_grad()
        loss.backward()
        self.gen_optim.step()

        # if use_ema is true, apply ema to the generator parameters
        if self.use_ema:
            self.ema_updater(self.gen_shadow, self.gen, self.ema_decay)

        # return the loss value
        return loss.item()

    @staticmethod
    def create_grid(samples, scale_factor, img_file):
        """
        utility function to create a grid of GAN samples
        :param samples: generated samples for storing
        :param scale_factor: factor for upscaling the image
        :param img_file: name of file to write
        :return: None (saves a file)
        """
        from torchvision.utils import save_image
        from torch.nn.functional import interpolate

        # upsample the image
        if scale_factor > 1:
            samples = interpolate(samples, scale_factor=scale_factor)

        # save the images:
        save_image(samples, img_file, nrow=int(np.sqrt(len(samples))),
                   normalize=True, scale_each=True)

    @staticmethod
    def __save_label_info_file(label_file, labels):
        """
        utility method for saving a file with labels
        :param label_file: path to the file to be written
        :param labels: label tensor
        :return: None (writes file to disk)
        """
        # write file with the labels written one per line
        with open(label_file, "w") as fp:
            for label in labels:
                fp.write(str(label.item()) + "\n")

    def one_hot_encode(self, labels):
        """
        utility method to one-hot encode the labels
        :param labels: tensor of labels (Batch)
        :return: enc_label: encoded one_hot label
        """
        # lazily build an identity-weight Embedding the first time; reused after
        if not hasattr(self, "label_oh_encoder"):
            self.label_oh_encoder = th.nn.Embedding(self.num_classes, self.num_classes)
            self.label_oh_encoder.weight.data = th.eye(self.num_classes)

        return self.label_oh_encoder(labels.view(-1))

    def train(self, dataset, epochs, batch_sizes,
              fade_in_percentage, start_depth=0, num_workers=3, feedback_factor=100,
              log_dir="./models/", sample_dir="./samples/", save_dir="./models/",
              checkpoint_factor=1):
        """
        Utility method for training the ProGAN. Note that you don't have to necessarily use this
        you can use the optimize_generator and optimize_discriminator for your own training routine.
        :param dataset: object of the dataset used for training.
                        Note that this is not the dataloader (we create dataloader in this method
                        since the batch_sizes for resolutions can be different).
                        Get_item should return (Image, label) in that order
        :param epochs: list of number of epochs to train the network for every resolution
        :param batch_sizes: list of batch_sizes for every resolution
        :param fade_in_percentage: list of percentages of epochs per resolution
                                   used for fading in the new layer
                                   not used for first resolution, but dummy value still needed.
        :param start_depth: start training from this depth. def=0
        :param num_workers: number of workers for reading the data. def=3
        :param feedback_factor: number of logs per epoch. def=100
        :param log_dir: directory for saving the loss logs. def="./models/"
        :param sample_dir: directory for saving the generated samples. def="./samples/"
        :param checkpoint_factor: save model after these many epochs.
                                  Note that only one model is stored per resolution.
                                  during one resolution, the checkpoint will be updated (Rewritten)
                                  according to this factor.
        :param save_dir: directory for saving the models (.pth files)
        :return: None (Writes multiple files to disk)
        """
        from pro_gan_pytorch.DataTools import get_data_loader

        assert self.depth == len(batch_sizes), "batch_sizes not compatible with depth"

        # turn the generator and discriminator into train mode
        self.gen.train()
        self.dis.train()
        if self.use_ema:
            self.gen_shadow.train()

        # create a global time counter
        global_time = time.time()

        # create fixed_input for debugging
        # (use the configured num_workers instead of a hard-coded worker count)
        temp_data_loader = get_data_loader(dataset, batch_sizes[0], num_workers=num_workers)
        _, fx_labels = next(iter(temp_data_loader))
        # reshape them properly
        fixed_labels = self.one_hot_encode(fx_labels.view(-1, 1)).to(self.device)
        fixed_input = th.randn(fixed_labels.shape[0],
                               self.latent_size - self.num_classes).to(self.device)
        fixed_input = th.cat((fixed_labels, fixed_input), dim=-1)
        del temp_data_loader  # delete the temp data_loader since it is not required anymore

        os.makedirs(sample_dir, exist_ok=True)  # make sure the directory exists
        self.__save_label_info_file(os.path.join(sample_dir, "labels.txt"), fx_labels)

        print("Starting the training process ... ")
        for current_depth in range(start_depth, self.depth):

            print("\n\nCurrently working on Depth: ", current_depth)
            current_res = np.power(2, current_depth + 2)
            print("Current resolution: %d x %d" % (current_res, current_res))

            data = get_data_loader(dataset, batch_sizes[current_depth], num_workers)
            ticker = 1

            for epoch in range(1, epochs[current_depth] + 1):
                start = timeit.default_timer()  # record time at the start of epoch

                print("\nEpoch: %d" % epoch)
                # a DataLoader exposes its batch count directly; no need to
                # materialize an iterator just to measure it
                total_batches = len(data)

                fader_point = int((fade_in_percentage[current_depth] / 100)
                                  * epochs[current_depth] * total_batches)

                # BUGFIX: when feedback_factor > total_batches the old
                # `int(total_batches / feedback_factor)` was 0 and the modulo
                # below raised ZeroDivisionError; clamp the interval to >= 1
                log_interval = max(int(total_batches / feedback_factor), 1)

                step = 0  # counter for number of iterations

                for (i, batch) in enumerate(data, 1):
                    # calculate the alpha for fading in the layers
                    alpha = ticker / fader_point if ticker <= fader_point else 1

                    # extract current batch of data for training
                    images, labels = batch
                    images = images.to(self.device)
                    labels = labels.view(-1, 1)

                    # create the input to the Generator
                    label_information = self.one_hot_encode(labels).to(self.device)
                    latent_vector = th.randn(images.shape[0],
                                             self.latent_size - self.num_classes).to(self.device)
                    gan_input = th.cat((label_information, latent_vector), dim=-1)

                    # optimize the discriminator:
                    dis_loss = self.optimize_discriminator(gan_input, images,
                                                           labels, current_depth, alpha)

                    # optimize the generator:
                    gen_loss = self.optimize_generator(gan_input, images,
                                                       labels, current_depth, alpha)

                    # provide a loss feedback
                    if i % log_interval == 0 or i == 1:
                        elapsed = time.time() - global_time
                        elapsed = str(datetime.timedelta(seconds=elapsed))
                        print("Elapsed: [%s]  batch: %d  d_loss: %f  g_loss: %f"
                              % (elapsed, i, dis_loss, gen_loss))

                        # also write the losses to the log file:
                        os.makedirs(log_dir, exist_ok=True)
                        log_file = os.path.join(log_dir, "loss_" + str(current_depth) + ".log")
                        with open(log_file, "a") as log:
                            log.write(str(step) + "\t" + str(dis_loss) +
                                      "\t" + str(gen_loss) + "\n")

                        # create a grid of samples and save it
                        os.makedirs(sample_dir, exist_ok=True)
                        gen_img_file = os.path.join(sample_dir, "gen_" + str(current_depth) +
                                                    "_" + str(epoch) + "_" +
                                                    str(i) + ".png")

                        # this is done to allow for more GPU space
                        self.gen_optim.zero_grad()
                        self.dis_optim.zero_grad()
                        with th.no_grad():
                            self.create_grid(
                                samples=self.gen(
                                    fixed_input,
                                    current_depth,
                                    alpha
                                ) if not self.use_ema
                                else self.gen_shadow(
                                    fixed_input,
                                    current_depth,
                                    alpha
                                ),
                                scale_factor=int(np.power(2, self.depth - current_depth - 1)),
                                img_file=gen_img_file,
                            )

                    # increment the alpha ticker and the step
                    ticker += 1
                    step += 1

                stop = timeit.default_timer()
                print("Time taken for epoch: %.3f secs" % (stop - start))

                if epoch % checkpoint_factor == 0 or epoch == 1 or epoch == epochs[current_depth]:
                    os.makedirs(save_dir, exist_ok=True)
                    gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                    dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(current_depth) + ".pth")
                    gen_optim_save_file = os.path.join(save_dir,
                                                       "GAN_GEN_OPTIM_" + str(current_depth)
                                                       + ".pth")
                    dis_optim_save_file = os.path.join(save_dir,
                                                       "GAN_DIS_OPTIM_" + str(current_depth)
                                                       + ".pth")

                    th.save(self.gen.state_dict(), gen_save_file)
                    th.save(self.dis.state_dict(), dis_save_file)
                    th.save(self.gen_optim.state_dict(), gen_optim_save_file)
                    th.save(self.dis_optim.state_dict(), dis_optim_save_file)

                    # also save the shadow generator if use_ema is True
                    if self.use_ema:
                        gen_shadow_save_file = os.path.join(save_dir, "GAN_GEN_SHADOW_" +
                                                            str(current_depth) + ".pth")
                        th.save(self.gen_shadow.state_dict(), gen_shadow_save_file)

        # put the gen, shadow_gen and dis in eval mode
        self.gen.eval()
        self.dis.eval()
        if self.use_ema:
            self.gen_shadow.eval()

        print("Training completed ...")
def train_model(train_dataset, train_num_each, val_dataset, val_num_each):
    """Train and validate a res34_tcn phase-recognition model on frame sequences.

    Relies on module-level globals defined outside this view: sequence_length,
    srate, num_gpu, use_gpu, workers, train_batch_size, val_batch_size, epochs,
    learning_rate, crop_type, optimizer_choice, sgd_adjust_lr, multi_optim,
    use_flip, exp_lr_scheduler, SeqSampler, get_useful_start_idx, res34_tcn.

    :param train_dataset: dataset yielding (image, phase_label, kdata) tuples
    :param train_num_each: per-video frame counts for the training split
    :param val_dataset: validation dataset with the same item layout
    :param val_num_each: per-video frame counts for the validation split
    :return: None (saves best weights .pth and per-epoch metrics .npy to disk)
    """
    num_train = len(train_dataset)
    num_val = len(val_dataset)

    # valid clip start positions: a clip of `sequence_length` frames must not
    # cross a video boundary
    train_useful_start_idx = get_useful_start_idx(sequence_length, train_num_each)
    #print('train_useful_start_idx ',train_useful_start_idx )
    val_useful_start_idx = get_useful_start_idx(sequence_length, val_num_each)
    #print('test_useful_start_idx ', val_useful_start_idx)

    # truncate to a multiple of num_gpu so DataParallel splits batches evenly
    num_train_we_use = len(train_useful_start_idx) // num_gpu * num_gpu
    # print('num_train_we_use',num_train_we_use) #92166
    num_val_we_use = len(val_useful_start_idx) // num_gpu * num_gpu
    # print('num_val_we_use', num_val_we_use)
    # num_train_we_use = 8000
    # num_val_we_use = 800

    train_we_use_start_idx = train_useful_start_idx[0:num_train_we_use]  # start positions of the training clips
    val_we_use_start_idx = val_useful_start_idx[0:num_val_we_use]

    np.random.seed(0)
    np.random.shuffle(train_we_use_start_idx)  # shuffle all clip start indices
    train_idx = []
    for i in range(num_train_we_use):  # number of training clips
        for j in range(sequence_length):
            train_idx.append(train_we_use_start_idx[i] + j * srate)  # frame positions; each image is one sample
    # print('train_idx',train_idx)

    val_idx = []
    for i in range(num_val_we_use):
        for j in range(sequence_length):
            val_idx.append(val_we_use_start_idx[i] + j * srate)
    # print('val_idx',val_idx)

    num_train_all = float(len(train_idx))
    num_val_all = float(len(val_idx))
    print('num of train dataset: {:6d}'.format(num_train))
    print('num train start idx : {:6d}'.format(len(train_useful_start_idx)))
    print('last idx train start: {:6d}'.format(train_useful_start_idx[-1]))
    print('num of train we use : {:6d}'.format(num_train_we_use))
    print('num of all train use: {:6d}'.format(int(num_train_all)))
    print('num of valid dataset: {:6d}'.format(num_val))
    print('num valid start idx : {:6d}'.format(len(val_useful_start_idx)))
    print('last idx valid start: {:6d}'.format(val_useful_start_idx[-1]))
    print('num of valid we use : {:6d}'.format(num_val_we_use))
    print('num of all valid use: {:6d}'.format(int(num_val_all)))

    # validation loader is built once; training loader is rebuilt each epoch
    # with a fresh shuffle (see loop below)
    val_loader = DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        # sampler=val_idx,
        sampler=SeqSampler(val_dataset, val_idx),
        num_workers=workers,
        pin_memory=False
    )
    model = res34_tcn()
    if use_gpu:
        model = model.cuda()

    model = DataParallel(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # model.parameters() and model.state_dict() both expose network parameters in
    # PyTorch; the former is typically passed to optimizers, the latter used for saving
    best_model_wts = copy.deepcopy(model.state_dict())
    best_val_accuracy = 0.0
    correspond_train_acc = 0.0

    # per-epoch records: [train_acc, train_loss, val_acc, val_loss]
    record_np = np.zeros([epochs, 4])

    for epoch in range(epochs):
        np.random.seed(epoch)
        np.random.shuffle(train_we_use_start_idx)  # reshuffle clip starts every epoch
        train_idx = []
        for i in range(num_train_we_use):
            for j in range(sequence_length):
                train_idx.append(train_we_use_start_idx[i] + j * srate)

        train_loader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            sampler=SeqSampler(train_dataset, train_idx),
            num_workers=workers,
            pin_memory=False
        )

        model.train()
        train_loss = 0.0
        train_corrects = 0
        train_start_time = time.time()
        num = 0
        train_num = 0
        for data in train_loader:
            num = num + 1
            # inputs, labels_phase = data
            inputs, labels_phase, kdata = data
            if use_gpu:
                inputs = Variable(inputs.cuda())  # NOTE(review): Variable is a legacy wrapper, deprecated since PyTorch 0.4
                labels = Variable(labels_phase.cuda())
                kdatas = Variable(kdata.cuda())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels_phase)
                kdatas = Variable(kdata)
            optimizer.zero_grad()  # reset gradients (clears d(loss)/d(weight) from the previous step)
            # outputs = model.forward(inputs)  # forward pass
            outputs = model.forward(inputs, kdatas)
            #outputs = F.softmax(outputs, dim=-1)
            _, preds = torch.max(outputs.data, -1)  # .data unwraps the Variable's tensor; torch.max(x, dim) returns max values and their indices
            #_, yp = torch.max(y.data, 1)
            #print(yp)
            # print(yp.shape)
            print(num)
            print(preds)
            print(labels)


            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # NOTE(review): accumulating `loss.data` keeps a tensor; `loss.item()`
            # would be the modern, graph-free alternative
            train_loss += loss.data
            train_corrects += torch.sum(preds == labels.data)
            train_num += labels.shape[0]
            print(train_corrects.cpu().numpy() / train_num)
            # checkpoint mid-epoch whenever running accuracy exceeds 0.75
            if train_corrects.cpu().numpy() / train_num > 0.75:
                torch.save(copy.deepcopy(model.state_dict()), 'test.pth')  # state_dict() stores only the parameters (fast, memory-light)

        train_elapsed_time = time.time() - train_start_time

        #train_accuracy1 = train_corrects1.cpu().numpy() / train_num
        train_accuracy = train_corrects.cpu().numpy() / train_num
        train_average_loss = train_loss / train_num

        # begin eval
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        val_num = 0
        val_start_time = time.time()
        # NOTE(review): no torch.no_grad() here — gradients are still tracked
        # during validation, wasting memory; confirm before changing
        for data in val_loader:
            inputs, labels_phase, kdata = data
            #inputs, labels_phase = data
            #labels_phase = labels_phase[(sequence_length - 1)::sequence_length]
            #kdata = kdata[(sequence_length - 1)::sequence_length]
            if use_gpu:
                inputs = Variable(inputs.cuda())
                labels = Variable(labels_phase.cuda())
                kdatas = Variable(kdata.cuda())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels_phase)
                kdatas = Variable(kdata)

            # crop_type 5/10 = multi-crop evaluation: fold the crop dimension into
            # the batch, run the model, then average predictions over crops
            if crop_type == 0 or crop_type == 1:
                #outputs = model.forward(inputs)
                outputs = model.forward(inputs, kdatas)
            elif crop_type == 5:
                inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
                inputs = inputs.view(-1, 3, 224, 224)
                outputs = model.forward(inputs, kdatas)
                # outputs = model.forward(inputs)
                outputs = outputs.view(5, -1, 3)
                outputs = torch.mean(outputs, 0)
            elif crop_type == 10:
                inputs = inputs.permute(1, 0, 2, 3, 4).contiguous()
                inputs = inputs.view(-1, 3, 224, 224)
                outputs = model.forward(inputs, kdatas)
                #outputs = model.forward(inputs)
                outputs = outputs.view(10, -1, 3)
                outputs = torch.mean(outputs, 0)

            #outputs = outputs[sequence_length - 1::sequence_length]

            _, preds = torch.max(outputs.data, -1)
            #_, yp = torch.max(y.data, 1)
            print(num)
            print(preds)
            print(labels)


            loss = criterion(outputs, labels)
            #loss = 0.05 * loss1 + 0.15 * loss2 + 0.3 * loss3 + 0.5 * loss4
            #loss = 0.05 * loss1 + 0.1 * loss2 + 0.25 * loss3 + 0.6 * loss4
            val_loss += loss.data
            val_corrects += torch.sum(preds == labels.data)
            val_num += labels.shape[0]
        val_elapsed_time = time.time() - val_start_time
        val_accuracy = val_corrects.cpu().numpy() / val_num
        val_average_loss = val_loss / val_num
        print('epoch: {:4d}'
              ' train in: {:2.0f}m{:2.0f}s'
              ' train loss: {:4.4f}'
              ' train accu: {:.4f}'
              ' valid in: {:2.0f}m{:2.0f}s'
              ' valid loss: {:4.4f}'
              ' valid accu: {:.4f}'
              .format(epoch,
                      train_elapsed_time // 60,
                      train_elapsed_time % 60,
                      train_average_loss,
                      train_accuracy,
                      val_elapsed_time // 60,
                      val_elapsed_time % 60,
                      val_average_loss,
                      val_accuracy))

        # SGD-specific LR schedule; exp_lr_scheduler is a module-level global
        if optimizer_choice == 0:
            if sgd_adjust_lr == 0:
                exp_lr_scheduler.step()
            elif sgd_adjust_lr == 1:
                exp_lr_scheduler.step(val_average_loss)

        # keep the weights with the best validation accuracy; on ties prefer
        # the epoch with higher training accuracy
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            correspond_train_acc = train_accuracy
            best_model_wts = copy.deepcopy(model.state_dict())
        if val_accuracy == best_val_accuracy:
            if train_accuracy > correspond_train_acc:
                correspond_train_acc = train_accuracy
                best_model_wts = copy.deepcopy(model.state_dict())


        record_np[epoch, 0] = train_accuracy
        record_np[epoch, 1] = train_average_loss
        record_np[epoch, 2] = val_accuracy
        record_np[epoch, 3] = val_average_loss
        # NOTE(review): this writes a new snapshot file per epoch (0.npy, 1.npy, ...)
        np.save(str(epoch) + '.npy', record_np)

    print('best accuracy: {:.4f} cor train accu: {:.4f}'.format(best_val_accuracy, correspond_train_acc))

    # encode best accuracies (x10000, as integers) into the output file names
    save_val = int("{:4.0f}".format(best_val_accuracy * 10000))
    save_train = int("{:4.0f}".format(correspond_train_acc * 10000))
    model_name = "tcn" \
                 + "_epoch_" + str(epochs) \
                 + "_length_" + str(sequence_length) \
                 + "_opt_" + str(optimizer_choice) \
                 + "_mulopt_" + str(multi_optim) \
                 + "_flip_" + str(use_flip) \
                 + "_crop_" + str(crop_type) \
                 + "_batch_" + str(train_batch_size) \
                 + "_train_" + str(save_train) \
                 + "_val_" + str(save_val) \
                 + ".pth"

    torch.save(best_model_wts, model_name)

    record_name = "tcn" \
                  + "_epoch_" + str(epochs) \
                  + "_length_" + str(sequence_length) \
                  + "_opt_" + str(optimizer_choice) \
                  + "_mulopt_" + str(multi_optim) \
                  + "_flip_" + str(use_flip) \
                  + "_crop_" + str(crop_type) \
                  + "_batch_" + str(train_batch_size) \
                  + "_train_" + str(save_train) \
                  + "_val_" + str(save_val) \
                  + ".npy"
    np.save(record_name, record_np)
示例#31
0
def train(args):
    """Train a semantic-segmentation model on the CCF dataset.

    Builds train/val DataLoaders, optionally resumes from a snapshot,
    trains with SGD + per-class-weighted cross entropy, streams the loss
    and validation accuracy to a running visdom server, and saves a full
    model pickle to ``snapshot/<arch>/<epoch>.pkl`` after every epoch.

    Args:
        args: parsed CLI namespace; reads traindir, split, img_rows,
            img_cols, batch_size, arch, snapshot, gpu (list of device
            ids), l_rate, step, n_epoch, clsloss_weight.
    """
    # Setup train DataLoader
    trainloader = CCFLoader(args.traindir, split=args.split,
                            is_transform=True, img_size=(args.img_rows, args.img_cols))
    n_classes = trainloader.n_classes
    TrainDataLoader = data.DataLoader(
        trainloader, batch_size=args.batch_size, num_workers=4, shuffle=True)

    # Setup validate DataLoader (fixed small batch, no shuffling)
    valloader = CCFLoader(args.traindir, split='val', is_transform=True, img_size=(
        args.img_rows, args.img_cols))
    VALDataLoader = data.DataLoader(
        valloader, batch_size=4, num_workers=4, shuffle=False)

    # Setup visdom for visualization; requires a visdom server to be running
    vis = visdom.Visdom()
    assert vis.check_connection()

    loss_window = vis.line(X=np.zeros((1,)),
                           Y=np.zeros((1)),
                           opts=dict(xlabel='minibatches',
                                     ylabel='Loss',
                                     title=args.arch+' Training Loss',
                                     legend=['Loss']))
    valacc_window = vis.line(X=np.zeros((1,)),
                             Y=np.zeros((1)),
                             opts=dict(xlabel='minibatches',
                                       ylabel='ACC',
                                       title='Val ACC',
                                       legend=['ACC']))

    # Setup model: fresh start or resume from a snapshot
    if args.snapshot is None:
        model = get_model(args.arch, n_classes)
        model = DataParallel(model.cuda(args.gpu[0]), device_ids=args.gpu)
        start_epoch = 0
    else:
        model = get_model(args.arch, n_classes)
        state_dict = torch.load(args.snapshot).state_dict()
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            # Strip the 'module.' prefix that DataParallel prepends so the
            # weights load into the bare (un-wrapped) model.
            new_state_dict[key[7:]] = value
        model.load_state_dict(new_state_dict)
        # Use the same device list as the fresh-model branch so resuming
        # runs on the GPUs the user requested (previously this used
        # range(len(args.gpu)), which is wrong when args.gpu != [0..k-1]).
        model = DataParallel(model.cuda(args.gpu[0]), device_ids=args.gpu)
        # Snapshot files are named '<epoch>.pkl'; recover the epoch number.
        start_epoch = int(os.path.basename(args.snapshot).split('.')[0])

    optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)

    print(model)

    # Start training
    for epoch in range(args.n_epoch):
        # Step the LR schedule for every epoch (including skipped ones) so a
        # resumed run sees the same learning rate it would have reached.
        adjust_learning_rate(optimizer, args.l_rate, epoch, args.step)
        if epoch < start_epoch:
            continue
        print("Epoch [%d/%d] learning rate: %f" % (epoch+1, args.n_epoch, optimizer.param_groups[0]['lr']))
        for i, (images, labels) in enumerate(TrainDataLoader):
            if torch.cuda.is_available():
                images = Variable(images.cuda(args.gpu[0]))
                labels = Variable(labels.cuda(args.gpu[0]))
            else:
                images = Variable(images)
                labels = Variable(labels)

            # Global minibatch index for the visdom x-axis.
            # (renamed from 'iter', which shadowed the builtin)
            global_step = len(TrainDataLoader) * epoch + i
            #poly_lr_scheduler(optimizer, args.l_rate, global_step)

            model.train()
            optimizer.zero_grad()
            outputs = model(images)
            if isinstance(outputs, tuple):
                # Model with an auxiliary binary-classification head.
                loss = cross_entropy2d(outputs[0], labels, weights_per_class) + args.clsloss_weight * bin_clsloss(outputs[1], labels)
            else:
                #loss = cross_entropy2d(outputs, labels)
                loss = cross_entropy2d(outputs, labels, weights_per_class)
                #loss = focal_loss2d(outputs, labels)

            loss.backward()
            optimizer.step()

            vis.line(
                X=torch.ones((1, 1)).cpu() * global_step,
                Y=torch.Tensor([loss.item()]).unsqueeze(0).cpu(),
                win=loss_window,
                update='append')

        print("Epoch [%d/%d] loss: %f" % (epoch+1, args.n_epoch, loss.item()))

        # validation
        loss, score = validate(model, VALDataLoader, n_classes)
        for i in range(n_classes):
            print(i, score['Class Acc'][i])
        vis.line(
            X=torch.ones((1, 1)).cpu()*(epoch+1),
            Y=torch.ones((1, 1)).cpu()*score['Overall Acc'],
            win=valacc_window,
            update='append')

        # makedirs (vs mkdir) also creates the parent 'snapshot/' directory
        # on first run instead of raising FileNotFoundError.
        snapshot_dir = "snapshot/{}".format(args.arch)
        os.makedirs(snapshot_dir, exist_ok=True)
        torch.save(model, "{}/{}.pkl".format(snapshot_dir, epoch+1))