Example 1
    def get_log_marginal_density(loader):
        model.eval()
        meter = AverageMeter()
        pbar = tqdm(total=len(loader))

        with torch.no_grad():
            for _, response, _, mask in loader:
                mb = response.size(0)
                response = response.to(device)
                mask = mask.long().to(device)

                posterior = Importance(
                    model.model,
                    guide=model.guide,
                    num_samples=args.num_posterior_samples,
                )
                posterior = posterior.run(response, mask)
                log_weights = torch.stack(posterior.log_weights)
                marginal = torch.logsumexp(log_weights, 0) - math.log(
                    log_weights.size(0))
                meter.update(marginal.item(), mb)

                pbar.update()
                pbar.set_postfix({'Marginal': meter.avg})

        pbar.close()
        print('====> Marginal: {:.4f}'.format(meter.avg))

        return meter.avg
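Every snippet on this page relies on an `AverageMeter` helper that is never shown. A minimal sketch, following the convention of the official PyTorch ImageNet example (a running weighted sum and count; some later snippets also pass a display name and format string to the constructor, which this sketch omits):

    class AverageMeter:
        """Tracks the running average of a scalar metric."""

        def __init__(self):
            self.reset()

        def reset(self):
            self.val = 0.0    # last value seen
            self.sum = 0.0    # weighted sum of values
            self.count = 0    # total weight (usually the number of samples)
            self.avg = 0.0    # running average

        def update(self, val, n=1):
            # `n` is typically the minibatch size, so `avg` is a per-sample average
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count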
Example 2
    def meta_val(self, model, meta_val_way, meta_val_shot, disable_tqdm,
                 callback, epoch):
        top1 = AverageMeter()
        model.eval()

        with torch.no_grad():
            tqdm_test_loader = warp_tqdm(self.val_loader, disable_tqdm)
            for i, (inputs, target, _) in enumerate(tqdm_test_loader):
                inputs, target = inputs.to(self.device), target.to(
                    self.device, non_blocking=True)
                output = model(inputs, feature=True)[0].cuda(0)
                train_out = output[:meta_val_way * meta_val_shot]
                train_label = target[:meta_val_way * meta_val_shot]
                test_out = output[meta_val_way * meta_val_shot:]
                test_label = target[meta_val_way * meta_val_shot:]
                train_out = train_out.reshape(meta_val_way, meta_val_shot,
                                              -1).mean(1)
                train_label = train_label[::meta_val_shot]
                prediction = self.metric_prediction(train_out, test_out,
                                                    train_label)
                acc = (prediction == test_label).float().mean()
                top1.update(acc.item())
                if not disable_tqdm:
                    tqdm_test_loader.set_description('Acc {:.2f}'.format(
                        top1.avg * 100))

        if callback is not None:
            callback.scalar('val_acc', epoch + 1, top1.avg, title='Val acc')
        return top1.avg
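`self.metric_prediction(train_out, test_out, train_label)` above classifies each query embedding by its nearest class prototype (the per-class mean features computed just before the call). A hedged sketch of one common implementation, using cosine similarity; the actual method may use a different metric:

    import torch.nn.functional as F

    def metric_prediction(self, train_out, test_out, train_label):
        # train_out: (way, dim) class prototypes; test_out: (n_query, dim).
        # Assign each query to the label of its most similar prototype.
        train_out = F.normalize(train_out, dim=-1)
        test_out = F.normalize(test_out, dim=-1)
        scores = test_out @ train_out.t()            # cosine similarities
        return train_label[scores.argmax(dim=1)]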
    def train(epoch):
        model.train()
        train_loss = AverageMeter()
        pbar = tqdm(total=len(train_loader))

        for batch_idx, (index, response, _, mask) in enumerate(train_loader):
            mb = response.size(0)
            index = index.to(device)
            response = response.to(device)
            mask = mask.long().to(device)
        
            optimizer.zero_grad()
            response_mu = model(index, response, mask)
            loss = F.binary_cross_entropy(response_mu, response.float(), reduction='none')
            loss = loss * mask
            loss = loss.mean()
            loss.backward()
            optimizer.step()

            train_loss.update(loss.item(), mb)

            pbar.update()
            pbar.set_postfix({'Loss': train_loss.avg})

        pbar.close()
        print('====> Train Epoch: {} Loss: {:.4f}'.format(epoch, train_loss.avg))

        return train_loss.avg
    def get_log_marginal_density(loader):
        model.eval()
        meter = AverageMeter()
        pbar = tqdm(total=len(loader))

        with torch.no_grad():
            for _, response, _, mask in loader:
                mb = response.size(0)
                response = response.to(device)
                mask = mask.long().to(device)

                marginal = model.log_marginal(
                    response, 
                    mask, 
                    num_samples = args.num_posterior_samples,
                )
                marginal = torch.mean(marginal)
                meter.update(marginal.item(), mb)

                pbar.update()
                pbar.set_postfix({'Marginal': meter.avg})
        
        pbar.close()
        print('====> Marginal: {:.4f}'.format(meter.avg))

        return meter.avg
Example 5
    def train(epoch):
        model.train()
        train_loss = AverageMeter()
        pbar = tqdm(total=len(train_loader))

        for batch_idx, (index, response, _, mask) in enumerate(train_loader):
            mb = response.size(0)
            index = index.to(device)
            response = response.to(device)
            mask = mask.long().to(device)
            annealing_factor = get_annealing_factor(epoch, batch_idx)
        
            optimizer.zero_grad()
            outputs = model(index, response, mask)
            loss = model.elbo(*outputs, annealing_factor=annealing_factor)
            loss.backward()
            optimizer.step()

            train_loss.update(loss.item(), mb)

            pbar.update()
            pbar.set_postfix({'Loss': train_loss.avg})

        pbar.close()
        print('====> Train Epoch: {} Loss: {:.4f}'.format(epoch, train_loss.avg))

        return train_loss.avg
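This training loop (like the later ones that pass an `annealing_factor` to `model.elbo` or `svi.step`) weights the KL term of the ELBO by a factor returned from `get_annealing_factor(epoch, batch_idx)`, which is not shown. A common choice, and purely an assumption here, is a linear warm-up from 0 to 1 over the first few epochs:

    def get_annealing_factor(epoch, batch_idx, n_warmup_epochs=10,
                             batches_per_epoch=100):
        # Linear KL annealing: ramp the KL weight from 0 to 1 over the first
        # `n_warmup_epochs` epochs, then hold it at 1. The defaults here are
        # illustrative assumptions, not the original hyperparameters.
        step = epoch * batches_per_epoch + batch_idx
        return min(1.0, step / float(n_warmup_epochs * batches_per_epoch))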
Example 6
def val(val_loader, model):
    val_nmi = AverageMeter()
    model.eval()

    start_idx = 0
    with torch.no_grad():
        for it, (idx, inputs, labels) in enumerate(val_loader):

            # ============ multi-res forward passes ... ============
            emb, output = model(inputs)
            emb = emb.detach()
            bs = inputs[0].size(0)

            # ============ deepcluster-v2 val nmi ... ============
            nmi = 0
            for h in range(len(args.nmb_prototypes)):
                scores = output[h] / args.temperature
                _, cluster_assignments = scores.max(1)
                nmi += normalized_mutual_info_score(
                    labels.repeat(sum(args.nmb_crops)).cpu().numpy(),
                    cluster_assignments.cpu().numpy())
            nmi /= len(args.nmb_prototypes)

            # ============ misc ... ============
            val_nmi.update(nmi)

    return val_nmi.avg
Example 7
    def __init__(self, env):
        # game params
        self.board_x, self.board_y = env.get_ub_board_size()
        self.action_size = env.n_actions
        self.n_inputs = env.n_inputs
        self.lr = args.lr
        self.env = env
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        super(Policy, self).__init__()
        self.conv1 = nn.Conv2d(self.n_inputs,
                               args.num_channels,
                               3,
                               stride=1,
                               padding=1).to(self.device)
        self.conv2 = nn.Conv2d(args.num_channels,
                               args.num_channels,
                               3,
                               stride=1,
                               padding=1).to(self.device)
        self.conv3 = nn.Conv2d(args.num_channels,
                               args.num_channels,
                               3,
                               stride=1).to(self.device)
        self.conv4 = nn.Conv2d(args.num_channels,
                               args.num_channels,
                               3,
                               stride=1).to(self.device)

        self.bn1 = nn.BatchNorm2d(args.num_channels).to(self.device)
        self.bn2 = nn.BatchNorm2d(args.num_channels).to(self.device)
        self.bn3 = nn.BatchNorm2d(args.num_channels).to(self.device)
        self.bn4 = nn.BatchNorm2d(args.num_channels).to(self.device)
        # conv3 and conv4 are unpadded, so each trims 2 pixels per spatial
        # dimension, giving a (board_x - 4) x (board_y - 4) feature map
        self.fc1 = nn.Linear(args.num_channels*(self.board_x - 4)*(self.board_y - 4) \
                             + env.agent_step_dim, 1024).to(self.device)
        self.fc_bn1 = nn.BatchNorm1d(1024).to(self.device)

        self.fc2 = nn.Linear(1024, 512).to(self.device)
        self.fc_bn2 = nn.BatchNorm1d(512).to(self.device)

        self.fc3 = nn.Linear(512, self.action_size).to(self.device)

        self.fc4 = nn.Linear(512, 1).to(self.device)

        self.entropies = 0
        self.pi_losses = AverageMeter()
        self.v_losses = AverageMeter()
        self.action_probs = [[], []]
        self.state_values = [[], []]
        self.rewards = [[], []]
        self.next_states = [[], []]
        if args.optimizer == 'adas':
            self.optimizer = Adas(self.parameters(), lr=self.lr)
        elif args.optimizer == 'adam':
            self.optimizer = Adam(self.parameters(), lr=self.lr)
        else:
            self.optimizer = SGD(self.parameters(), lr=self.lr)
    def test(epoch):
        model.eval()
        test_loss = AverageMeter()
        pbar = tqdm(total=len(test_loader))

        with torch.no_grad():
            for _, response, _, mask in test_loader:
                mb = response.size(0)
                response = response.to(device)
                mask = mask.long().to(device)

                if args.n_norm_flows > 0:
                    (
                        response,
                        mask,
                        response_mu,
                        ability_k,
                        ability,
                        ability_mu,
                        ability_logvar,
                        ability_logabsdetjac,
                        item_feat_k,
                        item_feat,
                        item_feat_mu,
                        item_feat_logvar,
                        item_feat_logabsdetjac,
                    ) = model(response, mask)
                    loss = model.elbo(
                        response,
                        mask,
                        response_mu,
                        ability,
                        ability_mu,
                        ability_logvar,
                        item_feat,
                        item_feat_mu,
                        item_feat_logvar,
                        use_kl_divergence=False,
                        ability_k=ability_k,
                        item_feat_k=item_feat_k,
                        ability_logabsdetjac=ability_logabsdetjac,
                        item_logabsdetjac=item_feat_logabsdetjac,
                    )
                else:
                    outputs = model(response, mask)
                    loss = model.elbo(*outputs)
                test_loss.update(loss.item(), mb)

                pbar.update()
                pbar.set_postfix({'Loss': test_loss.avg})

        pbar.close()
        print('====> Test Epoch: {} Loss: {:.4f}'.format(epoch, test_loss.avg))

        return test_loss.avg
    def train(epoch):
        model.train()
        train_loss = AverageMeter()
        pbar = tqdm(total=len(train_loader))

        for batch_idx, (_, response, _, _) in enumerate(train_loader):
            mb = response.size(0)
            item_index = torch.arange(num_item).to(device)
            response = response.to(device)

            if mb != args.batch_size:
                pbar.update()
                continue

            with torch.no_grad():
                item_index = item_index.unsqueeze(0).repeat(mb, 1)
                item_index[(response == -1).squeeze(2)] = -1

                # build what dkvmn_irt expects
                q_data = item_index.clone()
                a_data = response.clone().squeeze(2)
                # ??? https://github.com/ckyeungac/DeepIRT/blob/master/load_data.py
                qa_data = q_data + a_data * num_item
                qa_data[(response == -1).squeeze(2)] = -1

                # map q_data and qa_data to 0 to N+1
                q_data = q_data + 1
                qa_data = qa_data + 1
                label = response.clone().squeeze(2)

            optimizer.zero_grad()
            pred_zs, student_abilities, question_difficulties = \
                model(q_data, qa_data, label)
            loss = model.get_loss(
                pred_zs,
                student_abilities,
                question_difficulties,
                label,
            )
            loss.backward()
            # https://github.com/ckyeungac/DeepIRT/blob/master/configs.py
            # clip_grad_norm was renamed to the in-place clip_grad_norm_ in PyTorch
            nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

            train_loss.update(loss.item(), mb)

            pbar.update()
            pbar.set_postfix({'Loss': train_loss.avg})

        pbar.close()
        print('====> Train Epoch: {} Loss: {:.4f}'.format(
            epoch, train_loss.avg))

        return train_loss.avg
Example 10
def save_json(args, model, reglog, optimizer, loader):
    pred_label = []
    log_top1 = AverageMeter()

    for iter_epoch, (inp, target) in enumerate(loader):
        # measure data loading time

        learning_rate_decay(optimizer, len(loader) * args.epoch + iter_epoch, args.lr)

        # start at iter start_iter
        if iter_epoch < args.start_iter:
            continue

        # move to gpu
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if 'VOC2007' in args.data_path:
            target = target.float()

        # forward
        with torch.no_grad():
            output = model(inp)

        output = reglog(output)
        _, pred = output.topk(1, 1, True, True)
        pred = pred.t()

        pred_var = pred.data.cpu().numpy().reshape(-1) 
        for i in range(len(pred_var)):
            pred_label.append(pred_var[i])
  
        prec1 = accuracy(args, output, target)
        log_top1.update(prec1.item(), output.size(0)) 


    def load_json(file_path):
        assert os.path.exists(file_path), "{} does not exist".format(file_path)
        with open(file_path, 'r') as fp:
            data = json.load(fp)
        img_names = list(data.keys())
        return img_names
    
    json_predictions,img_names = {}, []
    img_names = load_json('./val_targets.json')

    for idx in range(len(pred_label)):
        json_predictions[img_names[idx]] = int(pred_label[idx])
    output_file = os.path.join(args.json_save_path, args.json_save_name)
 
    with open(output_file, 'w') as fp:
        json.dump(json_predictions, fp)   

    return log_top1.avg
Example 11
def val(val_loader, model, queue):
    norm_mut_info = AverageMeter()
    use_the_queue = False

    model.eval()
    end = time.time()
    with torch.no_grad():
        for it, (inputs, labels) in enumerate(val_loader):
            # normalize the prototypes
            with torch.no_grad():
                w = model.module.prototypes.weight.data.clone()
                w = nn.functional.normalize(w, dim=1, p=2)
                model.module.prototypes.weight.copy_(w)

            # ============ multi-res forward passes ... ============
            embedding, output = model(inputs)
            embedding = embedding.detach()
            bs = inputs[0].size(0)

            # ============ swav loss ... ============
            loss = 0
            for i, crop_id in enumerate(args.crops_for_assign):
                with torch.no_grad():
                    out = output[bs * crop_id:bs * (crop_id + 1)].detach()

                    # time to use the queue
                    if queue is not None:
                        if use_the_queue or not torch.all(queue[i,
                                                                -1, :] == 0):
                            use_the_queue = True
                            out = torch.cat(
                                (torch.mm(queue[i],
                                          model.module.prototypes.weight.t()),
                                 out))
                        # fill the queue
                        queue[i, bs:] = queue[i, :-bs].clone()
                        queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) *
                                                  bs]

                    # get assignments
                    q = distributed_sinkhorn(out)[-bs:]

            score, cluster_assignments = q.max(1)
            cluster_assignments = cluster_assignments.cpu().numpy()
            nmi = normalized_mutual_info_score(labels.cpu().numpy(),
                                               cluster_assignments)

            # ============ misc ... ============
            norm_mut_info.update(nmi)

    return norm_mut_info.avg
Example 12
def validate_with_softmax(val_loader, model, criterion, epoch, writer=None, threshold=0.5):

    # switch to evaluate mode
    model.eval()

    losses = AverageMeter('Loss', ":.4e")
    top1 = AverageMeter('Acc@1', ':6.2f')

    pbar = tqdm(val_loader)
    with torch.no_grad():
        for i, (images, target) in enumerate(pbar):
            if torch.cuda.is_available():
                images = images.cuda()
                target = target.cuda()
        
            # compute output
            output = model(images)
            loss = criterion(output, target)

            acc1 = accuracy(output, target, topk=(1,))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0][0], images.size(0))

            pbar.set_description('Validation')
            
        print(" * Acc@1 {top1.avg:.3f}".format(top1=top1))
        if writer:
            writer.add_scalar('Test/Loss', losses.avg, epoch)
            writer.add_scalar('Test/Top1_acc', top1.avg, epoch)

    return top1.avg
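The `accuracy(output, target, topk=(1,))` call above returns a list with one tensor per requested k. A sketch in the style of the PyTorch ImageNet example, which is an assumption about this particular helper since it is not shown (note that Example 10 uses a different `accuracy(args, output, target)` signature):

    def accuracy(output, target, topk=(1,)):
        # Top-k accuracy (in percent) for the given logits and integer labels.
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
            pred = pred.t()                                       # (maxk, batch)
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.mul_(100.0 / batch_size))
            return res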
Example 13
def test(model, criterion, test_loader, run_config):
    device = torch.device(run_config['device'])

    model.eval()

    loss_meter = AverageMeter()
    correct_meter = AverageMeter()
    start = time.time()
    with torch.no_grad():
        for step, (data, targets) in enumerate(test_loader):
            data = data.to(device)
            targets = targets.to(device)

            outputs = model(data)
            loss = criterion(outputs, targets)

            _, preds = torch.max(outputs, dim=1)

            loss_ = loss.item()
            correct_ = preds.eq(targets).sum().item()
            num = data.size(0)

            loss_meter.update(loss_, num)
            correct_meter.update(correct_, 1)

        accuracy = correct_meter.sum / len(test_loader.dataset)

        elapsed = time.time() - start

    test_log = collections.OrderedDict({
        'loss': loss_meter.avg,
        'accuracy': accuracy,
        'time': elapsed
    })
    return test_log
    def test(epoch):
        model.eval()
        test_loss = AverageMeter()
        pbar = tqdm(total=len(test_loader))

        with torch.no_grad():
            for _, response, _, _ in test_loader:
                mb = response.size(0)
                item_index = torch.arange(num_item).to(device)
                response = response.to(device)

                if mb != args.batch_size:
                    pbar.update()
                    continue

                with torch.no_grad():
                    item_index = item_index.unsqueeze(0).repeat(mb, 1)
                    item_index[(response == -1).squeeze(2)] = -1

                    # build what dkvmn_irt expects
                    q_data = item_index.clone()
                    a_data = response.clone().squeeze(2)
                    # ??? https://github.com/ckyeungac/DeepIRT/blob/master/load_data.py
                    qa_data = q_data + a_data * num_item
                    qa_data[(response == -1).squeeze(2)] = -1

                    # map q_data and qa_data to 0 to N+1
                    q_data = q_data + 1
                    qa_data = qa_data + 1
                    label = response.clone().squeeze(2)

                pred_zs, student_abilities, question_difficulties = \
                    model(q_data, qa_data, label)
                loss = model.get_loss(
                    pred_zs,
                    student_abilities,
                    question_difficulties,
                    label,
                )
                test_loss.update(loss.item(), mb)

                pbar.update()
                pbar.set_postfix({'Loss': test_loss.avg})

        pbar.close()
        print('====> Test Epoch: {} Loss: {:.4f}'.format(epoch, test_loss.avg))

        return test_loss.avg
Example 15
    def val(epoch):
        model.eval()
        loss_meters = [AverageMeter() for _ in range(n_planes)]

        with torch.no_grad():
            for i in range(args.n_train_models):
                val_loader = val_loaders[i]

                for x_i, _ in val_loader:
                    batch_size = len(x_i)
                    x_i = x_i.to(device)

                    context_x_i, context_z_i = sample_minibatch(
                        val_loader.dataset, batch_size, args.n_mlp_samples)
                    context_x_i = context_x_i.to(device)
                    context_z_i = context_z_i.to(device)
                    context_x_z_i = torch.cat([context_x_i, context_z_i],
                                              dim=2)

                    z_mu_i, z_logvar_i = model(x_i, context_x_z_i)
                    # NOTE: z_i is not defined in this snippet; presumably it is the
                    # latent paired with x_i (discarded as `_` in the loop above)
                    loss_i = compiled_inference_objective(
                        z_i, z_mu_i, z_logvar_i)

                    loss_meters[i].update(loss_i.item(), batch_size)

        loss_meter_avgs = [meter.avg for meter in loss_meters]
        loss_meter_avgs = np.array(loss_meter_avgs)

        print('====> Test Epoch: {}\tAverage Loss: {:.4f}'.format(
            epoch, np.mean(loss_meter_avgs)))

        return loss_meter_avgs
    def sample_posterior_mean(loader):
        model.eval()
        meter = AverageMeter()
        pbar = tqdm(total=len(loader))

        with torch.no_grad():
            
            response_sample_set = []

            for _, response, _, mask in loader:
                mb = response.size(0)
                response = response.to(device)
                mask = mask.long().to(device)

                _, ability_mu, _, _, item_feat_mu, _ = \
                    model.encode(response, mask)
                
                response_sample = model.decode(ability_mu, item_feat_mu).cpu()
                response_sample_set.append(response_sample.unsqueeze(0))

                pbar.update()

            response_sample_set = torch.cat(response_sample_set, dim=1)

            pbar.close()

        return {'response': response_sample_set}
    def train(epoch):
        model.train()
        train_loss = AverageMeter()
        pbar = tqdm(total=len(train_loader))

        for batch_idx, (_, response, _, mask) in enumerate(train_loader):
            mb = response.size(0)
            response = response.to(device)
            mask = mask.long().to(device)
            annealing_factor = get_annealing_factor(epoch, batch_idx)
        
            optimizer.zero_grad()
            if args.n_norm_flows > 0:
                (
                    response, mask, response_mu, 
                    ability_k, ability, 
                    ability_mu, ability_logvar, ability_logabsdetjac, 
                    item_feat_k, item_feat, 
                    item_feat_mu, item_feat_logvar, item_feat_logabsdetjac,
                ) = model(response, mask)
                loss = model.elbo(
                    response, mask, response_mu, 
                    ability, ability_mu, ability_logvar,
                    item_feat, item_feat_mu, item_feat_logvar, 
                    annealing_factor = annealing_factor,
                    use_kl_divergence = False,
                    ability_k = ability_k,
                    item_feat_k = item_feat_k,
                    ability_logabsdetjac = ability_logabsdetjac,
                    item_logabsdetjac = item_feat_logabsdetjac,
                )
            else:
                outputs = model(response, mask)
                loss = model.elbo(*outputs, annealing_factor=annealing_factor,
                                use_kl_divergence=True)
            loss.backward()
            optimizer.step()

            train_loss.update(loss.item(), mb)

            pbar.update()
            pbar.set_postfix({'Loss': train_loss.avg})

        pbar.close()
        print('====> Train Epoch: {} Loss: {:.4f}'.format(epoch, train_loss.avg))

        return train_loss.avg
Example 18
    def val(epoch):
        model.eval()
        loss_meter = AverageMeter()

        with torch.no_grad():
            for data in val_loader:
                batch_size = data.size(0)
                data = data.to(device)

                z_mu, z_logvar = model(data)
                loss = compiled_inference_objective(z, z_mu, z_logvar)

                loss_meter.update(loss.item(), batch_size)

        print('====> Test Epoch: {}\tLoss: {:.4f}'.format(
            epoch, loss_meter.avg))
        return loss_meter.avg
Example 19
    def test(epoch):
        model.eval()
        test_loss = AverageMeter()
        pbar = tqdm(total=len(test_loader))

        with torch.no_grad():
            for _, response, _, mask in test_loader:
                mb = response.size(0)
                response = response.to(device)
                mask = mask.long().to(device)

                loss = svi.evaluate_loss(response, mask)
                test_loss.update(loss, mb)

                pbar.update()
                pbar.set_postfix({'Loss': test_loss.avg})

        pbar.close()
        print('====> Test Epoch: {} Loss: {:.4f}'.format(epoch, test_loss.avg))

        return test_loss.avg
Example 20
    def train(epoch):
        model.train()
        train_loss = AverageMeter()
        pbar = tqdm(total=len(train_loader))

        for batch_idx, (_, response, _, mask) in enumerate(train_loader):
            mb = response.size(0)
            response = response.to(device)
            mask = mask.long().to(device)
            annealing_factor = get_annealing_factor(epoch, batch_idx)

            loss = svi.step(response, mask, annealing_factor)
            train_loss.update(loss, mb)

            pbar.update()
            pbar.set_postfix({'Loss': train_loss.avg})

        pbar.close()
        print('====> Train Epoch: {} Loss: {:.4f}'.format(
            epoch, train_loss.avg))

        return train_loss.avg
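Examples 19 and 20 drive training through a Pyro `svi` object (`svi.step(...)`, `svi.evaluate_loss(...)`) that is built elsewhere. A plausible construction, assuming `model.model` and `model.guide` are Pyro model/guide callables that accept `(response, mask)` plus an optional `annealing_factor`, and that the script's `args` namespace carries the learning rate:

    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import Adam

    # Extra positional arguments of svi.step(...) / svi.evaluate_loss(...) are
    # forwarded to both the model and the guide, which is how the
    # annealing_factor in Example 20 reaches them.
    pyro_optimizer = Adam({'lr': args.lr})
    svi = SVI(model.model, model.guide, pyro_optimizer, loss=Trace_ELBO())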
Example 21
    def test(self):
        num_batches = self.test_len // self.config.optim_params.batch_size
        tqdm_batch = tqdm(total=num_batches,
                          desc="[Epoch {}]".format(self.current_epoch))

        self.model.eval()
        epoch_losses = [AverageMeter() for _ in range(self.n_test_datasets)]

        with torch.no_grad():
            for batch_i, data_list in enumerate(self.test_loader):
                if self.config.model_params.name == 'ns':
                    batch_size = data_list[0][0].size(0)
                    data = torch.stack([x[0] for x in data_list])
                    data = data.to(self.device)
                    out = self.model(data)
                    loss = self.model.estimate_marginal(data, n_samples=10)
                    epoch_losses[0].update(loss.item(), batch_size)
                else:
                    for i in range(self.n_test_datasets):
                        data_i, _ = data_list[i]
                        data_i = data_i.float()
                        data_i = data_i.to(self.device)
                        batch_size = len(data_i)
                        sample_list_i = sample_minibatch_from_cache(
                            self.test_caches[i], batch_size,
                            self.config.model_params.n_data_samples)
                        sample_list_i = sample_list_i.float()
                        sample_list_i = sample_list_i.to(self.device)
                        if self.config.model_params.name == 'meta':
                            loss = self.model.estimate_marginal(data_i,
                                                                sample_list_i,
                                                                i,
                                                                n_samples=10)
                        elif self.config.model_params.name == 'vhe':
                            loss = self.model.estimate_marginal(data_i,
                                                                sample_list_i,
                                                                n_samples=10)
                        elif self.config.model_params.name == 'vhe_vamprior':
                            loss = self.model.estimate_marginal(data_i,
                                                                sample_list_i,
                                                                n_samples=10)

                    epoch_losses[i].update(loss.item(), batch_size)
                tqdm_batch.update()

        self.current_val_iteration += 1
        self.current_val_loss = sum(meter.avg for meter in epoch_losses)
        if self.current_val_loss < self.best_val_loss:
            self.best_val_loss = self.current_val_loss
        self.val_losses.append(self.current_val_loss)
        tqdm_batch.close()
    def test(epoch):
        model.eval()
        test_loss = AverageMeter()
        pbar = tqdm(total=len(test_loader))

        with torch.no_grad():
            for index, response, _, mask in test_loader:
                mb = response.size(0)
                index = index.to(device)
                response = response.to(device)
                mask = mask.long().to(device)

                response_mu = model(index, response, mask)
                loss = F.binary_cross_entropy(response_mu, response.float())
                test_loss.update(loss.item(), mb)

                pbar.update()
                pbar.set_postfix({'Loss': test_loss.avg})

        pbar.close()
        print('====> Test Epoch: {} Loss: {:.4f}'.format(epoch, test_loss.avg))

        return test_loss.avg
Example 23
    def train(epoch):
        model.train()
        loss_meter = AverageMeter()

        for batch_idx, data_list in enumerate(train_loader):
            x_list = [data[0] for data in data_list]
            batch_size = len(x_list[0])

            loss = 0
            for i in range(n_planes):
                x_i = x_list[i]
                x_i = x_i.to(device)

                context_x_i, context_z_i = sample_minibatch(
                    train_datasets[i], batch_size, args.n_mlp_samples)
                context_x_i = context_x_i.to(device)
                context_z_i = context_z_i.to(device)
                context_x_z_i = torch.cat([context_x_i, context_z_i], dim=2)

                z_mu_i, z_logvar_i = model(x_i, context_x_z_i)
                loss_i = compiled_inference_objective(z_i, z_mu_i, z_logvar_i)
                loss += loss_i

            loss_meter.update(loss.item(), batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * batch_size, len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), -loss_meter.avg))

        print('====> Train Epoch: {}\tLoss: {:.4f}'.format(
            epoch, -loss_meter.avg))
        return loss_meter.avg
Example 24
File: lars.py Project: zlapp/suncet
    def step(self):
        with torch.no_grad():
            stats = AverageMeter()
            weight_decays = []
            for group in self.optim.param_groups:

                # -- takes weight decay control from wrapped optimizer
                weight_decay = group[
                    'weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)

                # -- user wants to exclude this parameter group from LARS
                #    adaptation
                if ('LARS_exclude' in group) and group['LARS_exclude']:
                    continue
                group['weight_decay'] = 0

                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    grad_norm = torch.norm(p.grad.data)

                    if param_norm != 0 and grad_norm != 0:
                        adaptive_lr = self.trust_coefficient * (param_norm) / (
                            grad_norm + param_norm * weight_decay + self.eps)

                        stats.update(adaptive_lr)
                        p.grad.data += weight_decay * p.data
                        p.grad.data *= adaptive_lr

        self.optim.step()
        # -- return weight decay control to wrapped optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]

        return stats
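The core of this LARS step is the layer-wise adaptive rate: each parameter's gradient is rescaled by trust_coefficient * ||w|| / (||grad|| + weight_decay * ||w|| + eps). A standalone sketch of just that computation (the trust coefficient default below is a typical value, not necessarily the one used by this wrapper):

    import torch

    def lars_adaptive_lr(param, grad, weight_decay=0.0,
                         trust_coefficient=0.001, eps=1e-8):
        # Layer-wise adaptive rate scaling for a single parameter tensor.
        # Returns 1.0 (no rescaling) when either norm is zero, mirroring the
        # guard in the step() above.
        param_norm = torch.norm(param)
        grad_norm = torch.norm(grad)
        if param_norm == 0 or grad_norm == 0:
            return 1.0
        return trust_coefficient * param_norm / (
            grad_norm + param_norm * weight_decay + eps)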
Example 25
    def train(epoch):
        model.train()
        loss_meter = AverageMeter()

        for batch_idx, data in enumerate(train_loader):
            batch_size = data.size(0)
            data = data.to(device)

            z_mu, z_logvar = model(data)
            loss = compiled_inference_objective(z, z_mu, z_logvar)
            loss_meter.update(loss.item(), batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * batch_size, len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), -loss_meter.avg))

        print('====> Train Epoch: {}\tLoss: {:.4f}'.format(
            epoch, loss_meter.avg))
        return loss_meter.avg
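`compiled_inference_objective(z, z_mu, z_logvar)` is not defined in these snippets, and `z` itself comes from outside the shown code. In inference compilation the usual objective is the negative log-density of the target latent `z` under the predicted Gaussian posterior N(z_mu, exp(z_logvar)); a sketch under that assumption (the original objective may differ):

    import torch
    from torch.distributions import Normal

    def compiled_inference_objective(z, z_mu, z_logvar):
        # Negative log-probability of the target latent under the predicted
        # diagonal Gaussian posterior, averaged over the minibatch.
        q = Normal(z_mu, torch.exp(0.5 * z_logvar))
        return -q.log_prob(z).sum(dim=-1).mean()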
Example 26
def train(model, optimizer, scheduler, criterion, train_loader, run_config):

    device = torch.device(run_config['device'])

    for param_group in optimizer.param_groups:
        current_lr = param_group['lr']

    model.train()

    loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    start = time.time()

    for step, (data, targets) in enumerate(train_loader):

        if torch.cuda.device_count() == 1:
            data = data.to(device)
            targets = targets.to(device)

        optimizer.zero_grad()

        outputs = model(data)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_ = loss.item()
        num = data.size(0)
        accuracy = utils.accuracy(outputs, targets)[0].item()

        loss_meter.update(loss_, num)
        accuracy_meter.update(accuracy, num)

        if scheduler is not None:
            scheduler.step()

    elapsed = time.time() - start

    train_log = collections.OrderedDict({
        'loss': loss_meter.avg,
        'accuracy': accuracy_meter.avg,
        'time': elapsed
    })
    return train_log
Example 27
    def sample_posterior_predictive(loader):
        # draw samples from the approximate posterior and send back
        model.eval()
        meter = AverageMeter()
        pbar = tqdm(total=len(loader))

        with torch.no_grad():

            response_sample_set = []

            for _, response, _, mask in loader:
                mb = response.size(0)
                response = response.to(device)
                mask = mask.long().to(device)

                ability_mu, ability_logvar, item_feat_mu, item_feat_logvar = \
                    model.guide(response, mask)

                ability_scale = torch.exp(0.5 * ability_logvar)
                item_feat_scale = torch.exp(0.5 * item_feat_logvar)

                ability_posterior = dist.Normal(ability_mu, ability_scale)
                item_feat_posterior = dist.Normal(item_feat_mu,
                                                  item_feat_scale)

                ability_samples = ability_posterior.sample(
                    [args.num_posterior_samples])
                item_feat_samples = item_feat_posterior.sample(
                    [args.num_posterior_samples])

                response_samples = []
                for i in range(args.num_posterior_samples):
                    ability_i = ability_samples[i]
                    item_feat_i = item_feat_samples[i]
                    response_i = model.generate(ability_i, item_feat_i).cpu()
                    response_samples.append(response_i)
                response_samples = torch.stack(response_samples)
                response_sample_set.append(response_samples)

                pbar.update()

            response_sample_set = torch.cat(response_sample_set, dim=1)

            pbar.close()

        return {'response': response_sample_set}
  def __init__(self,
               train_dataset, val_dataset, test_dataset,
               model, hyper_dict, experiment_name,
               device, cross_validation=False):
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.test_dataset = test_dataset

    self.handler = LockableModelSaveHandler(self)

    self.model = model
    self.best_model = copy.deepcopy(model)
    self.best_val_loss = None

    self.epochs = hyper_dict['epochs']
    self.batch_size = hyper_dict['batch_size']
    self.num_workers = hyper_dict['num_workers']
    self.hyper_dict = hyper_dict

    self.experiment_name = experiment_name
    self.device = device
    self.cross_validation = cross_validation

    key_lst = ['time']
    for split in ('train', 'val', 'test'):
      for metric in ('loss', 'acc'):
        key_lst.append(f"{split}_{metric}")

    self.avg_meter = {key: AverageMeter() for key in key_lst}
    self.tag_str = {key: "" for key in key_lst}

    self.train_ldr = DataLoader(train_dataset, batch_size=self.batch_size,
                                num_workers=self.num_workers, shuffle=True)
    self.val_ldr = DataLoader(val_dataset, batch_size=self.batch_size,
                              num_workers=self.num_workers, shuffle=False)
    self.test_ldr = DataLoader(test_dataset, batch_size=self.batch_size,
                               num_workers=self.num_workers, shuffle=False)
    self.optimizer = get_optimizer(model.parameters(), hyper_dict)

    # state variables
    self.current_iter = 0
Example 29
def train(train_loader, model, optimizer, epoch, lr_schedule, queue, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    softmax = nn.Softmax(dim=1).cuda()
    model.train()

    end = time.time()
    for it, inputs in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # update learning rate
        iteration = epoch * len(train_loader) + it
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]

        # normalize the prototypes
        with torch.no_grad():
            w = model.module.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            model.module.prototypes.weight.copy_(w)

        # ============ data split ===========
        inputs, target = inputs
        # ============ multi-res forward passes ... ============
        embedding, output = model(inputs)
        embedding = embedding.detach()
        bs = inputs[0].size(0)

        # ============ EMA class-wise feature vector ==========
        for b in range(bs):
            queue[target[b]] = queue[target[b]] * 0.99 + (
                embedding[b] + embedding[bs + b]) * 0.01 / 2
        queue = nn.functional.normalize(queue, dim=1, p=2)
        dist.all_reduce(queue)
        queue /= args.world_size
        queue = nn.functional.normalize(queue, dim=1, p=2)
        # ============ swav loss ... ============
        loss = 0

        with torch.no_grad():
            q = torch.mm(queue, model.module.prototypes.weight.t())
            q = q / args.epsilon
            if args.improve_numerical_stability:
                M = torch.max(q)
                dist.all_reduce(M, op=dist.ReduceOp.MAX)
                q -= M

            q = torch.exp(q).t()
            q = sinkhorn(q, args.sinkhorn_iterations)
        # q = distributed_sinkhorn(q, args.sinkhorn_iterations)

        # match q /w label (1000, num_p) --> (bsz, num_p)
        for b in range(bs):
            if b == 0:
                matched_q = q[target[b]].unsqueeze(0)
            else:
                matched_q = torch.cat([matched_q, q[target[b]].unsqueeze(0)],
                                      0)

        # cluster assignment prediction
        subloss = 0
        for v in np.arange(np.sum(args.nmb_crops)):
            p = softmax(output[bs * v:bs * (v + 1)] / args.temperature)
            subloss -= torch.mean(torch.sum(matched_q * torch.log(p), dim=1))
        loss += subloss / np.sum(args.nmb_crops)

        # ============ backward and optim step ... ============
        optimizer.zero_grad()
        if args.use_fp16:
            with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # cancel some gradients
        if iteration < args.freeze_prototypes_niters:
            for name, p in model.named_parameters():
                if "prototypes" in name:
                    p.grad = None
        optimizer.step()

        # ============ misc ... ============
        losses.update(loss.item(), inputs[0].size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if args.rank == 0 and it % 50 == 0:
            logger.info("Epoch: [{0}][{1}]\t"
                        "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                        "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                        "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                        "Lr: {lr:.4f}".format(
                            epoch,
                            it,
                            batch_time=batch_time,
                            data_time=data_time,
                            loss=losses,
                            lr=optimizer.optim.param_groups[0]["lr"],
                        ))
    return (epoch, losses.avg), queue
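The `sinkhorn(q, args.sinkhorn_iterations)` call above turns the exponentiated prototype scores into a balanced assignment matrix. A minimal single-process Sinkhorn-Knopp sketch in the spirit of the SwAV implementation (assumed; the exact function used here is not shown):

    import torch

    def sinkhorn(Q, n_iters):
        # Q: (num_prototypes, num_samples), non-negative scores.
        # Alternately normalize rows and columns so prototypes are used evenly,
        # then return per-sample assignments of shape
        # (num_samples, num_prototypes) whose rows each sum to 1.
        with torch.no_grad():
            Q = Q / Q.sum()
            K, B = Q.shape
            for _ in range(n_iters):
                Q = Q / Q.sum(dim=1, keepdim=True)   # rows sum to 1
                Q = Q / K
                Q = Q / Q.sum(dim=0, keepdim=True)   # columns sum to 1
                Q = Q / B
            Q = Q * B                                # make each column sum to 1
            return Q.t()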
Example 30
def train_loc_model(model, data_loaders, optimizer, scheduler, seg_loss, num_epochs, weight_dir, snapshot_name, log_dir, best_score=0):

    writer = SummaryWriter(log_dir + 'localization')
    print('Tensorboard is recording into folder: ' + log_dir + 'localization')

    torch.cuda.empty_cache()

    for epoch in range(num_epochs):
        losses = AverageMeter()

        dices = AverageMeter()
        iterator = data_loaders['train']
        iterator = tqdm(iterator)
        model.train()
        for i, sample in enumerate(iterator):
            imgs = sample["img"].cuda(non_blocking=True)
            msks = sample["msk"].cuda(non_blocking=True)
        
            out = model(imgs)

            loss = seg_loss(out, msks)

            with torch.no_grad():
                _probs = torch.sigmoid(out[:, 0, ...])
                dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

            losses.update(loss.item(), imgs.size(0))

            dices.update(dice_sc, imgs.size(0))

            iterator.set_description("Epoch {}/{}, lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                    epoch, num_epochs, scheduler.get_lr()[-1], loss=losses, dice=dices))
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.999)
            optimizer.step()

            writer.add_scalar('Train/Loss', losses.avg, epoch)
            writer.add_scalar('Train/Dice', dices.avg, epoch)
            writer.flush()
        
        if epoch % 2 == 0:
            torch.cuda.empty_cache()

            model = model.eval()
            dices0 = []

            _thr = 0.5
            iterator = data_loaders['val']
            iterator = tqdm(iterator)
            with torch.no_grad():
                for i, sample in enumerate(iterator):
                    msks = sample["msk"].numpy()
                    imgs = sample["img"].cuda(non_blocking=True)
            
                    out = model(imgs)

                    msk_pred = torch.sigmoid(out[:, 0, ...]).cpu().numpy()
            
                    for j in range(msks.shape[0]):
                        dices0.append(dice(msks[j, 0], msk_pred[j] > _thr))

            d = np.mean(dices0)

            writer.add_scalar('Val/Dice', d, epoch)
            writer.flush()

            print("Val Dice: {}".format(d))

            if d > best_score:
                best_score = d
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_score': d,
                }, path.join(weight_dir, snapshot_name + '_best'))

            print("score: {}\tscore_best: {}".format(d, best_score))

    writer.close()

    return best_score
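The validation half of the loop above scores binarized masks with a `dice(gt, pred)` helper that is not shown. The Dice coefficient is 2 * |A intersect B| / (|A| + |B|); a small NumPy sketch under that assumption (the epsilon guard against empty masks is an addition):

    import numpy as np

    def dice(gt, pred, eps=1e-6):
        # Dice coefficient between two binary masks.
        gt = gt.astype(bool)
        pred = pred.astype(bool)
        intersection = np.logical_and(gt, pred).sum()
        return (2.0 * intersection + eps) / (gt.sum() + pred.sum() + eps)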