예제 #1
0
def main():
    """Fit a linear classifier on frozen SmallAlexNet features and print accuracy.

    The encoder is built, put in eval mode and probed once with a single
    training sample to discover the flattened feature size at
    ``opt.layer_index``. Its weights are then loaded from
    ``opt.encoder_checkpoint``. Only a 10-way ``nn.Linear`` head is optimized
    (Adam, MultiStepLR schedule). Encoder forwards run under ``torch.no_grad()``,
    so the encoder is never updated. Validation accuracy is printed after every
    epoch.
    """
    opt = parse_option()
    dev = opt.gpu

    torch.cuda.set_device(dev)
    # NOTE(review): benchmark=True lets cuDNN auto-select algorithms, which may
    # work against the reproducibility deterministic=True aims for.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    encoder = SmallAlexNet(feat_dim=opt.feat_dim).to(dev)
    encoder.eval()
    train_loader, val_loader = get_data_loaders(opt)

    # One forward pass on a single sample tells us the classifier input width.
    with torch.no_grad():
        first_image, _ = train_loader.dataset[0]
        batch_of_one = first_image.unsqueeze(0).to(dev)
        eval_numel = encoder(batch_of_one, layer_index=opt.layer_index).numel()
    print(f'Feature dimension: {eval_numel}')

    encoder.load_state_dict(torch.load(opt.encoder_checkpoint, map_location=dev))
    print(f'Loaded checkpoint from {opt.encoder_checkpoint}')

    classifier = nn.Linear(eval_numel, 10).to(dev)
    optim = torch.optim.Adam(
        classifier.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optim, gamma=opt.lr_decay_rate, milestones=opt.lr_decay_epochs)

    loss_meter = AverageMeter('loss')
    it_time_meter = AverageMeter('iter_time')
    for epoch in range(opt.epochs):
        for meter in (loss_meter, it_time_meter):
            meter.reset()
        tic = time.time()
        for ii, (images, labels) in enumerate(train_loader):
            optim.zero_grad()
            # Frozen encoder: no autograd graph is recorded for its forward.
            with torch.no_grad():
                feats = encoder(images.to(dev),
                                layer_index=opt.layer_index).flatten(1)
            loss = F.cross_entropy(classifier(feats), labels.to(dev))
            loss_meter.update(loss, images.shape[0])
            loss.backward()
            optim.step()
            # Iteration time includes the data-loading wait before this batch.
            it_time_meter.update(time.time() - tic)
            if ii % opt.log_interval == 0:
                print(
                    f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(train_loader)}\t{loss_meter}\t{it_time_meter}"
                )
            tic = time.time()
        scheduler.step()
        val_acc = validate(opt, encoder, classifier, val_loader)
        print(f"Epoch {epoch}/{opt.epochs}\tval_acc {val_acc*100:.4g}%")
예제 #2
0
def main():
    """Train a SmallAlexNet encoder with the alignment + uniformity objective.

    Each iteration encodes the two views ``im_x`` and ``im_y`` in one
    concatenated forward pass, splits the output back into ``x`` and ``y``,
    and minimizes::

        align_w * align_loss(x, y, alpha)
            + unif_w * (uniform_loss(x, t) + uniform_loss(y, t)) / 2

    with SGD and a MultiStepLR schedule. The encoder is wrapped in
    ``nn.DataParallel`` over ``opt.gpus``. After the last epoch the unwrapped
    module's state dict is saved to ``<opt.save_folder>/encoder.pth``.
    """
    opt = parse_option()

    print(
        f'Optimize: {opt.align_w:g} * loss_align(alpha={opt.align_alpha:g}) + {opt.unif_w:g} * loss_uniform(t={opt.unif_t:g})'
    )

    torch.cuda.set_device(opt.gpus[0])
    # NOTE(review): benchmark=True lets cuDNN auto-select algorithms, which may
    # work against the reproducibility deterministic=True aims for.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    encoder = nn.DataParallel(
        SmallAlexNet(feat_dim=opt.feat_dim).to(opt.gpus[0]), opt.gpus)

    optim = torch.optim.SGD(encoder.parameters(),
                            lr=opt.lr,
                            momentum=opt.momentum,
                            weight_decay=opt.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optim, gamma=opt.lr_decay_rate, milestones=opt.lr_decay_epochs)

    loader = get_data_loader(opt)

    align_meter = AverageMeter('align_loss')
    unif_meter = AverageMeter('uniform_loss')
    loss_meter = AverageMeter('total_loss')
    it_time_meter = AverageMeter('iter_time')
    for epoch in range(opt.epochs):
        align_meter.reset()
        unif_meter.reset()
        loss_meter.reset()
        it_time_meter.reset()
        t0 = time.time()
        for ii, (im_x, im_y) in enumerate(loader):
            optim.zero_grad()
            # Encode both views in a single pass, then split them back apart.
            x, y = encoder(
                torch.cat([im_x.to(opt.gpus[0]),
                           im_y.to(opt.gpus[0])])).chunk(2)
            align_loss_val = align_loss(x, y, alpha=opt.align_alpha)
            unif_loss_val = (uniform_loss(x, t=opt.unif_t) +
                             uniform_loss(y, t=opt.unif_t)) / 2
            loss = align_loss_val * opt.align_w + unif_loss_val * opt.unif_w
            # Weight every meter by batch size, so a short final batch affects
            # the uniformity average the same way it affects the others.
            align_meter.update(align_loss_val, x.shape[0])
            unif_meter.update(unif_loss_val, x.shape[0])
            loss_meter.update(loss, x.shape[0])
            loss.backward()
            optim.step()
            it_time_meter.update(time.time() - t0)
            if ii % opt.log_interval == 0:
                print(
                    f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(loader)}\t" +
                    f"{align_meter}\t{unif_meter}\t{loss_meter}\t{it_time_meter}"
                )
            t0 = time.time()
        scheduler.step()
    # Create the destination if needed so a long run does not fail at save time.
    os.makedirs(opt.save_folder, exist_ok=True)
    ckpt_file = os.path.join(opt.save_folder, 'encoder.pth')
    # Save the inner module so the checkpoint loads without DataParallel.
    torch.save(encoder.module.state_dict(), ckpt_file)
    print(f'Saved to {ckpt_file}')
예제 #3
0
class Evaluator():
    """Run a model over an evaluation data loader and log running metrics.

    Each batch is scored by either ``classfication_metrics`` (top-1/top-5
    accuracy and loss) or ``regression_metrics`` (absolute mean difference and
    loss). The ``metrics`` argument selects which one. Running averages live in
    ``AverageMeter`` objects and are only cleared by ``_reset_stats``.

    NOTE(review): ``eval`` does not call ``_reset_stats``, so averages
    accumulate across calls unless the caller resets them. Confirm this is
    intended.

    Args:
        data_loader: iterable yielding ``(images, labels)`` pairs.
        logger: object with an ``info(str)`` method used for output.
        config: object providing ``log_frequency`` (stored; not used by this
            class).
        name: prefix for TensorBoard tags.
        metrics: ``'classfication'`` (this spelling is part of the API)
            selects classification metrics. Any other value selects regression
            metrics.
        summary_writer: optional writer with ``add_scalar(tag, value, step)``.
    """

    def __init__(self,
                 data_loader,
                 logger,
                 config,
                 name='Evaluator',
                 metrics='classfication',
                 summary_writer=None):
        self.data_loader = data_loader
        self.logger = logger
        self.name = name
        self.summary_writer = summary_writer
        self.step = 0
        self.config = config
        self.log_frequency = config.log_frequency
        self.loss_meters = AverageMeter()
        self.acc_meters = AverageMeter()
        self.acc5_meters = AverageMeter()
        self.report_metrics = self.classfication_metrics if metrics == 'classfication' else self.regression_metrics

    def log(self, epoch, GLOBAL_STEP):
        """Log the latest batch payload plus running averages via ``self.logger``."""
        display = log_display(epoch=epoch,
                              global_step=GLOBAL_STEP,
                              time_elapse=self.time_used,
                              **self.logger_payload)
        self.logger.info(display)

    def eval(self, epoch, GLOBAL_STEP, model, criterion):
        """Evaluate ``model`` on every batch of ``self.data_loader``, then log once.

        The log line reports the last batch's metrics and the running averages.
        If the loader yields no batches there is nothing to report. In that
        case a short notice is logged instead of failing on missing state or
        re-logging stale results from a previous call.
        """
        n_batches = 0
        for images, labels in self.data_loader:
            self.eval_batch(x=images,
                            y=labels,
                            model=model,
                            criterion=criterion)
            n_batches += 1
        if n_batches == 0:
            self.logger.info(
                f'{self.name}: data loader produced no batches; nothing to log')
            return
        self.log(epoch, GLOBAL_STEP)

    def eval_batch(self, x, y, model, criterion):
        """Forward one batch without gradients and record its metrics.

        ``device`` is a module-level name defined outside this class
        (presumably a ``torch.device``; confirm). ``self.time_used`` covers
        only the forward pass and loss computation for this batch.
        """
        model.eval()
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        start = time.time()
        with torch.no_grad():
            pred = model(x)
            loss = criterion(pred, y)
        end = time.time()
        self.time_used = end - start
        self.step += 1
        self.report_metrics(pred, y, loss)

    def _add_scalar(self, metric, value):
        """Write ``value`` under ``<name>/<metric>`` if a summary writer is set."""
        if self.summary_writer is not None:
            # TensorBoard groups tags by "/". os.path.join would emit "\" on Windows.
            self.summary_writer.add_scalar(f'{self.name}/{metric}', value,
                                           self.step)

    def classfication_metrics(self, x, y, loss):
        """Update top-1/top-5 accuracy and loss meters for predictions ``x``."""
        acc, acc5 = accuracy(x, y, topk=(1, 5))
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(acc.item(), y.shape[0])
        self.acc5_meters.update(acc5.item(), y.shape[0])
        self.logger_payload = {
            "acc": acc,
            "acc_avg": self.acc_meters.avg,
            "top5_acc": acc5,
            "top5_acc_avg": self.acc5_meters.avg,
            "loss": loss,
            "loss_avg": self.loss_meters.avg
        }
        self._add_scalar('acc', acc)
        self._add_scalar('loss', loss)

    def regression_metrics(self, x, y, loss):
        """Update |mean(x - y)| and loss meters; ``acc_meters`` holds the diff."""
        diff = abs((x - y).mean().detach().item())
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(diff, y.shape[0])
        self.logger_payload = {
            "|diff|": diff,
            "|diff_avg|": self.acc_meters.avg,
            "loss": loss,
            "loss_avg": self.loss_meters.avg
        }
        self._add_scalar('diff', diff)
        self._add_scalar('loss', loss)

    def _reset_stats(self):
        """Clear all running averages."""
        self.loss_meters.reset()
        self.acc_meters.reset()
        self.acc5_meters.reset()
예제 #4
0
class Trainer():
    """Train a model for one pass over ``data_loader`` while logging metrics.

    Each batch is forwarded, back-propagated, gradient-clipped to
    ``config.grad_bound`` and stepped. Per-batch metrics come from
    ``classfication_metrics`` or ``regression_metrics`` (selected by
    ``metrics``) and are logged every ``config.log_frequency`` global steps.

    Args:
        data_loader: iterable yielding ``(images, labels)`` pairs.
        logger: object with an ``info(str)`` method.
        config: object providing ``log_frequency`` and ``grad_bound``.
        name: identifier stored on the instance.
        metrics: ``'classfication'`` (this spelling is part of the API)
            selects classification metrics. Any other value selects
            regression metrics.
    """

    def __init__(self,
                 data_loader,
                 logger,
                 config,
                 name='Trainer',
                 metrics='classfication'):
        self.data_loader = data_loader
        self.logger = logger
        self.name = name
        self.step = 0
        self.config = config
        self.log_frequency = config.log_frequency
        self.loss_meters = AverageMeter()
        self.acc_meters = AverageMeter()
        self.acc5_meters = AverageMeter()
        self.report_metrics = self.classfication_metrics if metrics == 'classfication' else self.regression_metrics

    def train(self, epoch, GLOBAL_STEP, model, optimizer, criterion):
        """Run one epoch over ``self.data_loader`` and return the new global step.

        ``device`` is a module-level name defined outside this class
        (presumably a ``torch.device``; confirm).
        """
        model.train()
        for images, labels in self.data_loader:
            images, labels = images.to(device, non_blocking=True), labels.to(
                device, non_blocking=True)
            self.train_batch(images, labels, model, criterion, optimizer)
            self.log(epoch, GLOBAL_STEP)
            GLOBAL_STEP += 1
        return GLOBAL_STEP

    def train_batch(self, x, y, model, criterion, optimizer):
        """Do one optimization step and record metrics, LR and gradient norm.

        ``self.time_used`` is the wall time of this step, excluding the
        device transfer done in ``train``.
        """
        start = time.time()
        model.zero_grad()
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        # clip_grad_norm_ returns the total norm measured before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.config.grad_bound)
        optimizer.step()
        self.report_metrics(pred, y, loss)
        # Store the scalar LR. A trailing comma here would wrap it in a 1-tuple.
        self.logger_payload['lr'] = optimizer.param_groups[0]['lr']
        self.logger_payload['|gn|'] = grad_norm
        end = time.time()
        self.step += 1
        self.time_used = end - start

    def log(self, epoch, GLOBAL_STEP):
        """Log the latest payload when ``GLOBAL_STEP`` is a multiple of ``log_frequency``."""
        if GLOBAL_STEP % self.log_frequency == 0:
            display = log_display(epoch=epoch,
                                  global_step=GLOBAL_STEP,
                                  time_elapse=self.time_used,
                                  **self.logger_payload)
            self.logger.info(display)

    def classfication_metrics(self, x, y, loss):
        """Update top-1/top-5 accuracy and loss meters; log top-1 and loss."""
        acc, acc5 = accuracy(x, y, topk=(1, 5))
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(acc.item(), y.shape[0])
        self.acc5_meters.update(acc5.item(), y.shape[0])
        self.logger_payload = {
            "acc": acc,
            "acc_avg": self.acc_meters.avg,
            "loss": loss,
            "loss_avg": self.loss_meters.avg
        }

    def regression_metrics(self, x, y, loss):
        """Update |mean(x - y)| and loss meters; ``acc_meters`` holds the diff."""
        diff = abs((x - y).mean().detach().item())
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(diff, y.shape[0])
        self.logger_payload = {
            "|diff|": diff,
            "|diff_avg|": self.acc_meters.avg,
            "loss": loss,
            "loss_avg": self.loss_meters.avg
        }

    def _reset_stats(self):
        """Clear all running averages."""
        self.loss_meters.reset()
        self.acc_meters.reset()
        self.acc5_meters.reset()
예제 #5
0
def train():
    """Train a D3NetMSS source-separation model for one target with NNabla.

    Steps, in order:
    - Check the NNabla version and build the (optionally multi-process)
      context.
    - Build a data iterator, sliced per rank when ``comm.n_procs > 1``.
    - Scale the learning rate and weight decay linearly by
      ``(n_procs * batch_size) / 6``.
    - Build an STFT -> spectrogram -> D3NetMSS graph whose loss is the mean
      squared error against the target spectrogram.
    - Train with Adam for ``args.epochs`` epochs.

    Rank 0 logs per-epoch monitors and saves ``<output>/<target>.h5`` after
    every epoch and ``<output>/<target>_final.h5`` at the end.

    Raises:
        ValueError: if the installed NNabla is older than v1.19.0.
    """
    # Check NNabla version: require v1.19.0+ (encoded as 11900).
    if get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )

    parser, args = get_train_args()

    # Get context and the communicator wrapper used for multi-process training.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging under <output>/<target>
    # NOTE(review): Monitor is created before rank 0 creates args.output below;
    # presumably Monitor creates its own directory. Confirm.
    monitor_path = os.path.join(args.output, args.target)
    monitor = Monitor(monitor_path)

    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per epoch",
                                      monitor,
                                      interval=1)

    # Only rank 0 creates the output directory.
    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False)

    # With several processes, each rank iterates over its own slice of the data.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Change max_iter, learning_rate and weight_decay according to the number
    # of devices for multi-gpu training. LR and weight decay are scaled
    # linearly by the global batch size relative to a reference batch of 6.
    default_batch_size = 6
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size

    # Iterations per epoch = dataset size // global batch size.
    max_iter = int(train_source._size // (comm.n_procs * args.batch_size))
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor

    # NOTE(review): printed on every rank, not only rank 0.
    print(f"max_iter per GPU-device:{max_iter}")

    # Calculate the statistics (mean and variance) of the dataset; they are
    # passed to D3NetMSS as input_mean / input_scale below.
    scaler_mean, scaler_std = get_statistics(args, train_source)

    # clear cache memory
    ext.clear_memory_cache()

    # Create input variables shaped [batch_size] + shape of one sample
    # (element 0 = mixture, element 1 = target, read from sample 0).
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    with open(f"./configs/{args.target}.yaml") as file:
        # Load target-specific hyperparameters (fft_size, hop_size,
        # n_channels, ... are read below).
        # NOTE(review): FullLoader on a repo-local config; yaml.safe_load
        # would suffice if the configs are plain mappings.
        hparams = yaml.load(file, Loader=yaml.FullLoader)

    # create training graph: STFT both signals, then convert to spectrograms
    # (mono when the target config has a single channel).
    mix_spec = spectogram(*stft(mixture_audio,
                                n_fft=hparams['fft_size'],
                                n_hop=hparams['hop_size'],
                                patch_length=256),
                          mono=(hparams['n_channels'] == 1))
    target_spec = spectogram(*stft(target_audio,
                                   n_fft=hparams['fft_size'],
                                   n_hop=hparams['hop_size'],
                                   patch_length=256),
                             mono=(hparams['n_channels'] == 1))

    # Parameters are created under a scope named after the target.
    with nn.parameter_scope(args.target):
        d3net = D3NetMSS(hparams,
                         comm=comm.comm,
                         input_mean=scaler_mean,
                         input_scale=scaler_std,
                         init_method='xavier')
        pred_spec = d3net(mix_spec)

    # Mean squared error between predicted and target spectrograms.
    loss = F.mean(F.squared_error(pred_spec, target_spec))
    # Presumably keeps loss.d readable after backward(clear_buffer=True).
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize LR Scheduler (AnnealingScheduler); presumably multiplies the
    # LR by 0.1 from epoch 40 on. Confirm against AnnealingScheduler.
    lr_scheduler = AnnealingScheduler(init_lr=args.lr,
                                      anneal_steps=[40],
                                      anneal_factor=0.1)

    # AverageMeter for mean loss calculation over the epoch
    losses = AverageMeter()

    for epoch in range(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # Gradients are all-reduced across ranks during backward.
                # NOTE(review): the callback is re-created every iteration;
                # it could likely be hoisted out of the loop. Confirm.
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            # Copy so the meter does not alias the variable's data buffer.
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        # The LR computed here is set after this epoch's updates, so it
        # applies from the next epoch on (epoch 0 trains with args.lr).
        lr = lr_scheduler.get_learning_rate(epoch)
        solver.set_learning_rate(lr)

        if comm.rank == 0:
            monitor_traing_loss.add(epoch, training_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            # save intermediate weights (overwritten every epoch)
            nn.save_parameters(f"{os.path.join(args.output, args.target)}.h5")

    if comm.rank == 0:
        # save final weights
        nn.save_parameters(
            f"{os.path.join(args.output, args.target)}_final.h5")