def __init__(self, log_dir, model, optimizer, evaluator, device, params):
     self.model = model
     self.params = params
     self.device = device
     self.optimizer = optimizer
     self.evaluator = evaluator
     self.writer = SummaryWriter(log_dir)
     self.nt_xent_criterion = NTXentLoss(self.device, params['batch_size'], **params['nce_loss'])
Example #2
 def _choose_loss(self):
     if self.config['loss_select'] == 'NT_Xent':
         print("using NT_Xent as loss func")
         return NTXentLoss(self.device, self.config['batch_size'], **self.config['loss'])
     elif self.config['loss_select'] == 'NT_Logistic':
         print("using NT_Logistic as loss func")
         return NTLogisticLoss(self.device, self.config['batch_size'], **self.config['loss'])
     elif self.config['loss_select'] == 'MarginTriplet':
         print("using MarginTriplet as loss func")
         return MarginTripletLoss(self.device, self.config['batch_size'], self.config['semi_hard'],
                                  **self.config['loss'])
     else:
         print('not a valid loss, use NT_Xent as default')
         return NTXentLoss(self.device, self.config['batch_size'], **self.config['loss'])
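For reference, a config shaped as follows would drive _choose_loss. The keys inside 'loss' are assumptions, inferred from Example #12, which passes temperature and use_cosine_similarity to the same loss classes:

config = {
    'loss_select': 'NT_Xent',
    'batch_size': 256,
    'semi_hard': 'No',  # read only by the MarginTriplet branch
    'loss': {
        'temperature': 0.5,
        'use_cosine_similarity': True,
    },
}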
Example #3
 def __init__(self, dataset, config):
     self.config = config
     self.device = self._get_device()
     self.writer = SummaryWriter()
     self.dataset = dataset
     self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'],
                                         **config['loss'])
Example #4
 def __init__(self, dataset, config, config_path, exp_name):
     self.config = config
     self.config_path = config_path
     self.device = self._get_device()
     self.writer = SummaryWriter(log_dir=os.path.join('runs', exp_name))
     self.dataset = dataset
     self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'],
                                         **config['loss'])
Example #5
 def __init__(self, dataset, config):
     self.config = config
     #self.device = self._get_device()
     self.device = 'cuda'
     self.writer = SummaryWriter()
     self.dataset = dataset
     self.batch_size = config['train']['train_batch_size_per_gpu']
     if self.device == 'cuda':
         self.batch_size = self.batch_size * config['gpu']['gpunum']
     self.nt_xent_criterion = NTXentLoss(self.device, self.batch_size, **config['loss'])
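For example, with train_batch_size_per_gpu = 64 and gpunum = 4, the criterion above is constructed with an effective batch size of 64 * 4 = 256.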
Example #6
 def _make_loss(self):
     if self.config['loss_type'] == 'nt_logistic':
         return NTLogisticLoss(self.device, self.config['batch_size'],
                               **self.config['loss'])
     elif self.config['loss_type'] == 'nt_xent':
         return NTXentLoss(self.device, self.config['batch_size'],
                           **self.config['loss'])
     elif self.config['loss_type'] == 'marginal_triplet':
         return MarginalTripletLoss(self.device, self.config['batch_size'],
                                    **self.config['loss'])
     else:
         # Fail fast instead of silently returning None.
         raise ValueError(f"unknown loss_type: {self.config['loss_type']}")
Example #7
 def __init__(self, dataset, config):
     self.config = config
     self.device = self._get_device()
     self.writer = SummaryWriter()
     self.dataset = dataset
     self.nt_xent_criterion = NTXentLoss(self.device, config["batch_size"],
                                         **config["loss"])
     self.dis_criterion = torch.nn.CrossEntropyLoss()
     self.normal_dist = tdist.Normal(torch.Tensor([0.0]),
                                     torch.Tensor([1.0]))
     self.disc_weight = config["disc_weight"]
Example #8
 def __init__(self, dataset, config):
     self.config = config
     self.device = get_device()
     self.writer = SummaryWriter()
     self.dataset = dataset
     if config['loss_func'] == 'sim':
         self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
     elif config['loss_func'] == 'siam':
         self.siam_loss = Siam(config['batch_size'])
     else:
         raise NotImplementedError(config['loss_func'])
Example #9
 def __init__(self, dataset, config):
     self.config = config
     self.device = self._get_device()
     self.writer = SummaryWriter()
     self.dataset = dataset
     self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'],
                                         **config['loss'])
     self.truncation = config['truncation']
     self.tokenizer = AutoTokenizer.from_pretrained(
         config['model']['bert_base_model']
     )  #, do_lower_case=config['model_bert']['do_lower_case'])
Example #10
 def __init__(self, dataset, config, lumbda, Checkpoint_Num):
     self.lumbda1 = lumbda
     self.lumbda2 = lumbda
     self.config = config
     self.Checkpoint_Num = Checkpoint_Num
     print('\nThe configuration of this model is as follows:\n', config)
     self.device = self._get_device()
     # self.writer = SummaryWriter()
     self.dataset = dataset
     self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'],
                                         **config['loss'])
Example #11
 def __init__(self, dataset, config):
     self.config = config
     self.device = self._get_device()
     self.writer = SummaryWriter()
     self.dataset = dataset
     self.nt_xent_criterion = NTXentLoss(self.device, config["batch_size"],
                                         **config["loss"])
     self.normal_dist = tdist.Normal(torch.Tensor([0.0]),
                                     torch.Tensor([1.0]))
     self.augmentor_type = config.get("augmentor_type", "cnn")
     self.augmentor_loss_type = config.get("augmentor_loss_type", "linear")
Example #12
 def _get_loss_strategy(self,
                        device,
                        batch_size,
                        temperature,
                        use_cosine_similarity,
                        mode,
                        semi_hard='No'):
     if mode == 'nt-xent':
         print('The Training Loss is NT-Xent.')
         return NTXentLoss(device, batch_size, temperature,
                           use_cosine_similarity)
     elif mode == 'nt-logistic':
         print('The Training Loss is NT-Logistic.')
         return NTLogisticLoss(device, batch_size, temperature,
                               use_cosine_similarity, semi_hard)
     elif mode == 'margin-triplet':
         print('The Training Loss is MarginTriplet.')
         return MarginTripletLoss(device, batch_size, temperature,
                                  use_cosine_similarity, semi_hard)
     else:
         print("Unknown mode chosen,using default nt-xent instead.")
         return NTXentLoss(device, batch_size, temperature,
                           use_cosine_similarity)
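For illustration, a trainer holding a device and a config dict might build its criterion from this strategy method like so (the attribute and key names below are assumptions):

criterion = self._get_loss_strategy(self.device,
                                    self.config['batch_size'],
                                    self.config['temperature'],
                                    self.config['use_cosine_similarity'],
                                    mode='nt-xent')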
Example #13
 def __init__(self, config):
     self.config = config
     self.device = self._get_device()
     self.train_loader = self._load_lvis_results()
     if self.config['loss']['type'] == 'nce':
         from loss.nt_xent import NTXentLoss
         self.loss_crit = NTXentLoss(self.device, config['batch_size'],
                                     **config['loss'])
     if self.config['loss']['include_hierarchical']:
         self.hierarchical_loss_crit = HierarchicalLoss(
             margin=config['loss']['margin'])
     if self.config['hyperbolic']:
         self.triplet_loss_crit = HTripletLoss(
             margin=config['loss']['margin'])
     else:
         self.triplet_loss_crit = TripletLoss(
             margin=config['loss']['margin'])
Example #14
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    # create model
    print("=> creating model '{}'".format(args.arch))

    model = ResNetSimCLR(base_model=args.arch,
                         out_dim=args.out_dim).to(args.gpu)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.num_workers = int(
                (args.num_workers + ngpus_per_node - 1) / ngpus_per_node)
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        # The reference implementation raises here; it is commented out so
        # single-GPU debugging still works:
        # raise NotImplementedError("Only DistributedDataParallel is supported.")
    # else:
    #     # The AllGather implementation (batch shuffle, queue update, etc.)
    #     # in this code only supports DistributedDataParallel.
    #     raise NotImplementedError("Only DistributedDataParallel is supported.")

    # Data loader
    train_loader, train_sampler = data_loader(args.dataset,
                                              args.data_path,
                                              args.batch_size,
                                              args.num_workers,
                                              download=args.download,
                                              distributed=args.distributed,
                                              supervised=False)

    #optimizer = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=args.weight_decay)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs,
                                                           eta_min=0,
                                                           last_epoch=-1)

    criterion = NTXentLoss(args.gpu, args.batch_size, args.temperature,
                           True).cuda(args.gpu)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if apex_support and args.fp16_precision:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O2',
                                          keep_batchnorm_fp32=True)

    cudnn.benchmark = True

    train(model, train_loader, train_sampler, criterion, optimizer, scheduler,
          args, ngpus_per_node)
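
# A minimal sketch of how main_worker is typically launched, one process per
# GPU. The parse_args() helper is hypothetical; only the attributes already
# used above are assumed to exist on `args`.
if __name__ == '__main__':
    import torch.multiprocessing as mp

    args = parse_args()  # hypothetical CLI parser
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size becomes the total process count across all nodes.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)
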
class SimCLRTrainer(object):
    def __init__(self, log_dir, model, optimizer, evaluator, device, params):
        self.model = model
        self.params = params
        self.device = device
        self.optimizer = optimizer
        self.evaluator = evaluator
        self.writer = SummaryWriter(log_dir)
        self.nt_xent_criterion = NTXentLoss(self.device, params['batch_size'], **params['nce_loss'])
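        # Keys consumed elsewhere in this class: batch_size, num_workers,
        # max_epochs, save_per_epoch, noise_blend, grad_combination_margin
        # (float or None), and nce_loss, the kwargs dict forwarded to
        # NTXentLoss above.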

    def _step(self, model, xis, xjs, xs, n_iter):

        # get the representations and the projections for both views
        zis = model(xis)  # [N,C]
        zjs = model(xjs)  # [N,C]

        # normalize projection feature vectors
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        if xs is not None:
            # Unaugmented datapoint. 
            zs = model(xs)
            zs = F.normalize(zs, dim=1)
        else:
            zs = None

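        # Note: this NTXentLoss variant also accepts the unaugmented view and
        # returns an intra-sample term alongside the usual inter-sample loss.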
        loss, loss_intra = self.nt_xent_criterion(zis, zjs, zs)
        return loss, loss_intra

    def train(self, train_dataset):
        train_loader = DataLoader(train_dataset, batch_size=self.params["batch_size"] * torch.cuda.device_count(),
                                  num_workers=self.params["num_workers"], drop_last=True, shuffle=False)

        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
        if not os.path.exists(model_checkpoints_folder):
            os.mkdir(model_checkpoints_folder)

        self.save_model(os.path.join(model_checkpoints_folder, 'model_000.pth'))

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=len(train_loader), eta_min=0,
                                                               last_epoch=-1)

        margin = self.params["grad_combination_margin"]
        if margin is not None:
            matcher = re.compile(r"encoder\.(\d+)")
            layers = dict()
            for name, _ in self.model.named_parameters():
                # print(f"{name}: {params.size()}")
                m = matcher.match(name)
                if m is None:
                    l = 10
                else:
                    l = int(m.group(1))
                layers[name] = l
            unique_entries = sorted(list(set(layers.values())))
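            # Interpolate per-layer mixing ratios linearly from `margin`
            # (lowest layers) to 1 - margin (highest); r later weights the
            # inter-sample gradient against the intra-sample one.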
            series = np.linspace(margin, 1 - margin, len(unique_entries)) 
            l2ratio = dict(zip(unique_entries, series))
            layer2ratio = { name : l2ratio[l] for name, l in layers.items() }

            log.info(f"Gradient margin: {margin}")
            for name, r in layer2ratio.items():
                log.info(f"  {name}: {r}")
        else:
            log.info("No gradient margin")

        n_iter = 0
        alpha = self.params["noise_blend"]

        for epoch_counter in range(self.params['max_epochs']):
            loss_record = []
            suffix = str(epoch_counter).zfill(3)

            # Add noise to weight once in a while
            if alpha > 0:
                for name, p in self.model.named_parameters():
                    with torch.no_grad():
                        if len(p.size()) < 2:
                            continue
                        w = torch.zeros_like(p, device=p.get_device())
                        torch.nn.init.xavier_uniform_(w)
                        p[:] = (1 - alpha) * p[:] + alpha * w

            for (xis, xjs, xs), _ in train_loader:
                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                if self.nt_xent_criterion.need_unaug_data():
                    xs = xs.to(self.device)
                else:
                    xs = None

                loss, loss_intra = self._step(self.model, xis, xjs, xs, n_iter)

                # if n_iter % self.params['log_every_n_steps'] == 0:
                #     self.writer.add_scalar('train_loss', loss, global_step=n_iter)

                all_loss = loss + loss_intra
                loss_record.append(all_loss.item())

                if margin is not None:
                    # Here we run backward twice, once per loss, and weight
                    # the gradients differently at each layer.
                    self.optimizer.zero_grad()
                    loss.backward(retain_graph=True)

                    inter_grads = dict()
                    for name, p in self.model.named_parameters():
                        # print(f"{name}: {p.size()}")
                        inter_grads[name] = p.grad.clone()

                    self.optimizer.zero_grad()
                    loss_intra.backward()
                    for name, p in self.model.named_parameters():
                        r = layer2ratio[name]
                        # Lower layer -> high ratio of loss_intra
                        p.grad *= (1 - r)
                        p.grad += inter_grads[name] * r
                else:
                    self.optimizer.zero_grad()
                    all_loss.backward()

                self.optimizer.step()
                n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                scheduler.step()
            self.writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter)

            log.info(f"Epoch {epoch_counter}: numIter: {n_iter} Loss: {np.mean(loss_record)}")
            if self.evaluator is not None:
                best_acc = self.evaluator.eval_model(deepcopy(self.model))
                log.info(f"Epoch {epoch_counter}: best_acc: {best_acc}")

            if epoch_counter % self.params["save_per_epoch"] == 0:
                # save checkpoints
                self.save_model(os.path.join(model_checkpoints_folder, f'model_{suffix}.pth'))

    def save_model(self, PATH):
        torch.save({
            'online_network_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, PATH)
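
For context, every example above delegates to an external NTXentLoss class. Below is a minimal sketch of the standard NT-Xent computation, assuming cosine similarity on L2-normalized projections; the real implementations also take a use_cosine_similarity flag, and the variant in Example #14 additionally handles an unaugmented view.

import torch
import torch.nn.functional as F

class NTXentLossSketch(torch.nn.Module):
    """Minimal NT-Xent (normalized temperature-scaled cross-entropy) sketch."""

    def __init__(self, device, batch_size, temperature):
        super().__init__()
        self.device = device
        self.batch_size = batch_size
        self.temperature = temperature

    def forward(self, zis, zjs):
        # zis, zjs: [N, C] projections of two augmented views of the same batch.
        z = F.normalize(torch.cat([zis, zjs], dim=0), dim=1)   # [2N, C]
        sim = torch.matmul(z, z.t()) / self.temperature        # [2N, 2N]
        # Mask self-similarity so it never competes as a candidate.
        mask = torch.eye(2 * self.batch_size, dtype=torch.bool,
                         device=self.device)
        sim = sim.masked_fill(mask, float('-inf'))
        # Row i's positive is the other view of the same image.
        targets = torch.cat([
            torch.arange(self.batch_size, 2 * self.batch_size),
            torch.arange(0, self.batch_size),
        ]).to(self.device)
        return F.cross_entropy(sim, targets)

# e.g. criterion = NTXentLossSketch('cpu', batch_size=256, temperature=0.5)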