Example #1
def train_one_epoch(args, train_loader, model, optimizer, weights=None):
    if args.loss.startswith('weighted'): weights = weights.to(args.device)
    losses = AverageMeter()
    model.train()
    if args.accumulation_steps > 1:
        print(
            f"Due to gradient accumulation of {args.accumulation_steps} using global batch size of {args.accumulation_steps*train_loader.batch_size}"
        )
        optimizer.zero_grad()
    tk0 = tqdm(train_loader, total=len(train_loader))
    for b_idx, data in enumerate(tk0):
        for key, value in data.items():
            data[key] = value.to(args.device)
        if args.accumulation_steps == 1 and b_idx == 0:
            optimizer.zero_grad()
        _, loss = model(**data, args=args, weights=weights)

        with torch.set_grad_enabled(True):
            loss.backward()
            if (b_idx + 1) % args.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        losses.update(loss.item(), train_loader.batch_size)
        tk0.set_postfix(loss=losses.avg)
    return losses.avg
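Every example on this page leans on an AverageMeter helper that the snippets themselves do not define. A minimal sketch that is consistent with how it is used here (an update(val, n) method plus .val and .avg attributes, and the optional name/format arguments seen in Examples #15 and #17) could look like this; treat it as an illustration, not the authors' exact class:

class AverageMeter:
    """Tracks the latest value and the running average of a metric (sketch)."""

    def __init__(self, name='meter', fmt=':f'):
        self.name = name    # optional label, only used when printing
        self.fmt = fmt      # optional format spec, e.g. ':.6f'
        self.val = 0.0      # most recent value passed to update()
        self.sum = 0.0      # weighted sum of all values seen so far
        self.count = 0      # total weight (typically the number of samples)
        self.avg = 0.0      # running average = sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # lets print(meter) show "name current (average)", as in Examples #15 and #17
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)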
Example #2
    def train(
        data_loader,
        model,
        optimizer,
        device,
        scheduler=None,
        accumulation_steps=1,
        use_tpu=False,
        fp16=False,
    ):
        if use_tpu and not __xla_available:
            raise Exception(
                "You want to use TPUs but you don't have pytorch_xla installed")
        if fp16 and not __apex_available:
            raise Exception(
                "You want to use fp16 but you don't have apex installed")
        if fp16 and use_tpu:
            raise Exception("Apex fp16 is not available when using TPUs")
        if fp16:
            accumulation_steps = 1
        losses = AverageMeter()
        predictions = []
        model.train()
        if accumulation_steps > 1:
            optimizer.zero_grad()
        tk0 = tqdm(data_loader, total=len(data_loader), disable=use_tpu)
        for b_idx, data in enumerate(tk0):
            for key, value in data.items():
                data[key] = value.to(device)
            if accumulation_steps == 1 and b_idx == 0:
                optimizer.zero_grad()
            _, loss = model(**data)

            if not use_tpu:
                with torch.set_grad_enabled(True):
                    if fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    if (b_idx + 1) % accumulation_steps == 0:
                        optimizer.step()
                        if scheduler is not None:
                            scheduler.step()
                        if b_idx > 0:
                            optimizer.zero_grad()
            else:
                loss.backward()
                xm.optimizer_step(optimizer)
                if scheduler is not None:
                    scheduler.step()
                if b_idx > 0:
                    optimizer.zero_grad()

            losses.update(loss.item(), data_loader.batch_size)
            tk0.set_postfix(loss=losses.avg)
        return losses.avg
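Note that Examples #1 and #2 back-propagate the raw per-batch loss while accumulating gradients. To make the accumulated update match a single pass over the larger effective batch, the loss is usually divided by accumulation_steps before backward(). A hedged sketch of that variant, reusing the names from Example #2 (model, data_loader, optimizer, device and accumulation_steps are assumed to exist):

optimizer.zero_grad()
for b_idx, data in enumerate(data_loader):
    data = {key: value.to(device) for key, value in data.items()}
    _, loss = model(**data)
    # scale so that the gradients summed over accumulation_steps mini-batches
    # approximate one update computed on the combined batch
    (loss / accumulation_steps).backward()
    if (b_idx + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()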
Example #3
def evaluate(data_loader, model, device, use_tpu=False):
    losses = AverageMeter()
    final_predictions = []
    model.eval()
    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader), disable=use_tpu)
        for b_idx, data in enumerate(tk0):
            for key, value in data.items():
                data[key] = value.to(device)
            predictions, loss = model(**data)
            predictions = predictions.cpu()
            losses.update(loss.item(), data_loader.batch_size)
            final_predictions.append(predictions)
            tk0.set_postfix(loss=losses.avg)
    return final_predictions, losses.avg
Example #4
def evaluate(args, valid_loader, model):
    losses = AverageMeter()
    final_preds = []
    model.eval()
    with torch.no_grad():
        tk0 = tqdm(valid_loader, total=len(valid_loader))
        for data in tk0:
            for key, value in data.items():
                data[key] = value.to(args.device)
            preds, loss = model(**data, args=args)
            if args.loss == 'crossentropy' or args.loss == 'weighted_cross_entropy':
                preds = preds.argmax(1)
            losses.update(loss.item(), valid_loader.batch_size)
            preds = preds.cpu().numpy()
            final_preds.extend(preds)
            tk0.set_postfix(loss=losses.avg)
    return final_preds, losses.avg
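Examples #1 and #4 share the same args-based interface, so a typical driver loop would tie them together roughly as sketched below (hypothetical: args.epochs, the loaders, model, optimizer, weights and the checkpoint path are assumed, not taken from the original code):

best_loss = float('inf')
for epoch in range(args.epochs):
    train_loss = train_one_epoch(args, train_loader, model, optimizer, weights=weights)
    preds, valid_loss = evaluate(args, valid_loader, model)
    print(f'epoch={epoch} train_loss={train_loss:.4f} valid_loss={valid_loss:.4f}')
    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), 'best_model.pt')  # hypothetical path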
Example #5
    def train(
        data_loader,
        model,
        optimizer,
        device,
        scheduler=None,
        accumulation_steps=1,
    ):
        losses = AverageMeter()
        predictions = []
        model.train()
        if accumulation_steps > 1:
            optimizer.zero_grad()
        tk0 = tqdm(data_loader, total=len(data_loader), disable=False)
        for b_idx, data in enumerate(tk0):
            for key, value in data.items():
                data[key] = value.to(device)
            
            if accumulation_steps == 1 and b_idx == 0:
                optimizer.zero_grad()
            
            _, loss = model(**data)

            with torch.set_grad_enabled(True):
                loss.backward()
                optimizer.step()

                if scheduler is not None:
                    scheduler.step()
                if b_idx > 0:
                    optimizer.zero_grad()
            
            losses.update(loss.item(), data_loader.batch_size)
            tk0.set_postfix(loss=losses.avg)
        
        return losses.avg
Example #6
    def train_epoch(self, epoch, phase):
        loss_ = AverageMeter()
        accuracy_ = AverageMeter()
        self.model.train()
        self.margin.train()

        for batch_idx, sample in enumerate(self.dataloaders[phase]):

            imageL, imageR, label = sample[0].to(self.device), \
                                    sample[1].to(self.device), sample[2].to(self.device)
            self.optimizer.zero_grad()
            with torch.set_grad_enabled(True):
                outputL, outputR = self.model(imageL), self.model(imageR)

                acc = 0
                loss = self.criterion([outputL, outputR], label)

                loss.backward()
                self.optimizer.step()

            loss_.update(loss.item(), label.size(0))
            accuracy_.update(acc, label.size(0))

            if batch_idx % 40 == 0:
                print(
                    'Train Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss:{:.6f}\tAcc:{:.6f} LR:{:.7f}'
                    .format(epoch, batch_idx * len(label),
                            len(self.dataloaders[phase].dataset),
                            100. * batch_idx / len(self.dataloaders[phase]),
                            loss.item(), 0,
                            self.optimizer.param_groups[0]['lr']))

        self.scheduler.step()

        print("Train Epoch Loss: {:.6f} Accuracy: {:.6f}".format(
            loss_.avg, accuracy_.avg))
        torch.save(
            self.model.state_dict(),
            './checkpoints/{}_{}_Contrastive_{:04d}.pth'.format(
                self.ckpt_tag, str(self.margin), epoch))
        torch.save(
            self.margin.state_dict(),
            './checkpoints/{}_512_{}_Contrastive_{:04d}.pth'.format(
                self.ckpt_tag, str(self.margin), epoch))
Example #7
def train(net, criterion, optimizer, writer, epoch, n_iter, loss_, t0):
    train_pos_dist = AverageMeter()
    train_neg_dist = AverageMeter()
    net.train()
    for batch_idx, (data1, data2, data3) in enumerate(train_loader):
        data1, data2, data3 = data1.cuda().float(), data2.cuda().float(
        ), data3.cuda().float()
        embedded_a, embedded_p, embedded_n = net(data1, data2, data3)

        dista, distb, loss_triplet, loss_total = criterion(
            embedded_a, embedded_p, embedded_n)
        loss_embedd = embedded_a.norm(2) + embedded_p.norm(
            2) + embedded_n.norm(2)
        loss = loss_triplet + 0.001 * loss_embedd

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_ += loss_total.item()

        train_pos_dist.update(dista.cpu().data.numpy().sum())
        train_neg_dist.update(distb.cpu().data.numpy().sum())
        writer.add_scalar('Train/Loss_Triplet', loss_triplet, n_iter)
        writer.add_scalar('Train/Loss_Embedd', loss_embedd, n_iter)
        writer.add_scalar('Train/Loss', loss, n_iter)
        writer.add_scalar('Train/Distance/Positive', train_pos_dist.avg,
                          n_iter)
        writer.add_scalar('Train/Distance/Negative', train_neg_dist.avg,
                          n_iter)
        n_iter += 1

        if batch_idx % 5 == 4:
            t1 = time.time()
            lr = optimizer.param_groups[0]['lr']
            print('[Epoch %d, Batch %4d] loss: %.8f time: %.5f lr: %.3e' %
                  (epoch + 1, batch_idx + 1, loss_ / 5, (t1 - t0) / 60, lr))
            t0 = t1
            loss_ = 0.0
    return n_iter
Example #8
def validate(val_loader, model, criterion, save_images, epoch, device):
    model.eval()

    batch_time, data_time, losses = AverageMeter(), AverageMeter(
    ), AverageMeter()

    end = time.time()
    already_saved_images = False
    for i, (input_gray, input_ab, target) in enumerate(val_loader):
        data_time.update(time.time() - end)

        # Use GPU
        input_gray, input_ab, target, model = input_gray.to(
            device), input_ab.to(device), target.to(device), model.to(device)

        output_ab = model(input_gray)
        loss = criterion(output_ab, input_ab)
        losses.update(loss.item(), input_gray.size(0))

        if save_images and not already_saved_images:
            already_saved_images = True
            for j in range(min(len(output_ab), 10)):
                save_path = {
                    'grayscale': 'outputs/gray/',
                    'colorized': 'outputs/color/'
                }
                save_name = 'img-{}-epoch-{}.jpg'.format(
                    i * val_loader.batch_size + j, epoch)
                convert_to_rgb(input_gray[j].cpu(),
                               ab_input=output_ab[j].detach().cpu(),
                               save_path=save_path,
                               save_name=save_name)

        batch_time.update(time.time() - end)
        end = time.time()

        print('Validate: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                  i, len(val_loader), batch_time=batch_time, loss=losses))

    print('Finished validation.')
    return losses.avg
Example #9
def train(train_loader, model, criterion, optimizer, epoch, device):
    print('Starting training epoch {}'.format(epoch))
    model.train()

    batch_time, data_time, losses = AverageMeter(), AverageMeter(
    ), AverageMeter()

    end = time.time()
    for i, (input_gray, input_ab, target) in enumerate(train_loader):
        input_gray, input_ab, target, model = input_gray.to(
            device), input_ab.to(device), target.to(device), model.to(device)

        data_time.update(time.time() - end)

        output_ab = model(input_gray)
        loss = criterion(output_ab, input_ab)
        losses.update(loss.item(), input_gray.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                  epoch,
                  i,
                  len(train_loader),
                  batch_time=batch_time,
                  data_time=data_time,
                  loss=losses))

    print('Finished training epoch {}'.format(epoch))
Example #10
def train(model: HiDDen, device: torch.device, this_run_folder: str,
          hidden_config: HiDDenConfiguration, train_options: TrainingOptions):
    train_data, val_data = hidden_utils.get_data_loaders(
        hidden_config, train_options)
    file_count = len(train_data.dataset)
    steps_in_epoch = file_count // train_options.batch_size \
                    + int(file_count % train_options.batch_size != 0)

    print_each = 10
    saved_images_size = (512, 512)

    for epoch in range(train_options.number_of_epochs):
        logging.info(
            f'\nStarting epoch {epoch + 1} / {train_options.number_of_epochs}')
        logging.info(f'Batch size = {train_options.batch_size}')
        logging.info(f'Steps in epoch {steps_in_epoch}')
        losses_accu = {}
        epoch_start = time.time()
        step = 1

        for image, _ in train_data:
            image = image.to(device)
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, _ = model.train_on_batch([image, message])

            if not losses_accu:
                for name in losses:
                    losses_accu[name] = AverageMeter()

            for name, loss in losses.items():
                losses_accu[name].update(loss)

            if step % print_each == 0 or step == steps_in_epoch:
                logging.info(
                    f'Epoch: {epoch + 1}/{train_options.number_of_epochs} Step: {step}/{steps_in_epoch}'
                )
                hidden_utils.log_progress(losses_accu)
                logging.info('-' * 40)
            step += 1

        train_duration = time.time() - epoch_start
        logging.info(
            f'Epoch {epoch + 1} training duration {train_duration:.2f}')
        logging.info('-' * 40)
        hidden_utils.write_losses(os.path.join(this_run_folder, 'train.csv'),
                                  losses_accu, epoch, train_duration)

        logging.info(
            f'Running validation for epoch {epoch + 1} / {train_options.number_of_epochs}'
        )
        losses_accu = {}
        for image, _ in val_data:
            image = image.to(device)
            message = torch.Tensor(
                np.random.choice(
                    [0, 1],
                    (image.shape[0], hidden_config.message_length))).to(device)
            losses, (encoded_images,
                     decoded_messages) = model.validate_on_batch([image, message])
            if not losses_accu:
                for name in losses:
                    losses_accu[name] = AverageMeter()
            for name, loss in losses.items():
                losses_accu[name].update(loss)
        hidden_utils.log_progress(losses_accu)
        logging.info('-' * 40)
        hidden_utils.save_checkpoint(
            model, train_options.experiment_name, epoch,
            os.path.join(this_run_folder, 'checkpoints'))
        hidden_utils.write_losses(
            os.path.join(this_run_folder, 'validation.csv'), losses_accu,
            epoch,
            time.time() - epoch_start)
Example #11
def validate(val_loader, model, criterion):
    model.decoderRNN.eval()  # eval mode (no dropout or batchnorm)
    model.encoderCNN.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    references = list(
    )  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    with T.no_grad():
        # Batches
        for i, (imgs, captions) in enumerate(val_loader, start=1):
            # Move to device, if available
            imgs = imgs.to(Constants.device)
            captions = captions.to(Constants.device)
            # Forward prop.
            outputs = model(imgs, captions[:-1])
            vocab_size = outputs.shape[2]
            outputs1 = outputs.reshape(-1, vocab_size)
            captions1 = captions.reshape(-1)
            loss = criterion(outputs1, captions1)

            # Keep track of metrics
            losses.update(loss.item(), len(captions1))
            top5 = accuracy(outputs1, captions1, 5)
            top5accs.update(top5, len(captions1))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % Hyper.print_freq == 0:
                print(
                    'Validation: [{0}/{1}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                        i,
                        len(val_loader),
                        batch_time=batch_time,
                        loss=losses,
                        top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            reference = get_sentence(captions1, model)
            references.append(reference)
            prediction = get_hypothesis(outputs1, model)
            hypotheses.append(prediction)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)

        print(
            f'\n * LOSS - {losses.avg}, TOP-5 ACCURACY - {top5accs.avg}, BLEU-4 - {bleu4}\n'
        )

    return bleu4
Example #12
    def train_epoch(self):
        """
        Train the model for one epoch.
        """

        print("-------------- Train epoch ------------------")

        batch_time = AverageMeter()
        data_time = AverageMeter()
        forward_time = AverageMeter()
        loss_time = AverageMeter()
        backward_time = AverageMeter()
        loss_total_am = AverageMeter()
        loss_loc_am = AverageMeter()
        loss_cls_am = AverageMeter()

        # switch to training mode
        self.model_dp.train()

        is_lr_change = self.epoch in [epoch for epoch, _ in self.lr_scales]
        if self.optimizer is None or is_lr_change:
            scale = None
            if self.optimizer is None:
                scale = 1.0
            if is_lr_change:
                scale = [
                    sc for epoch, sc in self.lr_scales if epoch == self.epoch
                ][0]
            self.learning_rate = self.base_learning_rate * scale
            if self.optimizer is None:
                self.optimizer = torch.optim.SGD(self.model_dp.parameters(),
                                                 self.learning_rate,
                                                 momentum=0.9,
                                                 weight_decay=0.0001)
            else:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.learning_rate

        do_dump_train_images = False
        detection_train_dump_dir = None
        if do_dump_train_images:
            detection_train_dump_dir = os.path.join(self.run_dir,
                                                    'detection_train_dump')
            clean_dir(detection_train_dump_dir)

        end = time.time()
        for batch_idx, sample in enumerate(self.train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            input, target, names, pil_images, annotations, stats = sample

            if do_dump_train_images:  # and random.random() < 0.01:
                dump_images(names, pil_images, annotations, None, stats,
                            self.model.labelmap, detection_train_dump_dir)

            input_var, target_var = self.wrap_sample_with_variable(
                input, target)

            # compute output
            forward_ts = time.time()
            encoded_tensor = self.model_dp(input_var)
            forward_time.update(time.time() - forward_ts)
            loss_ts = time.time()
            loss, loss_details = self.model.get_loss(encoded_tensor,
                                                     target_var)
            loss_time.update(time.time() - loss_ts)

            # record loss
            loss_total_am.update(loss_details["loss"], input.size(0))
            loss_loc_am.update(loss_details["loc_loss"], input.size(0))
            loss_cls_am.update(loss_details["cls_loss"], input.size(0))

            # compute gradient and do SGD step
            backward_ts = time.time()
            self.optimizer.zero_grad()
            loss.backward()
            clip_gradient(self.model, 2.0, 'by_max')
            self.optimizer.step()
            backward_time.update(time.time() - backward_ts)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % self.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Forward {forward_time.val:.3f} ({forward_time.avg:.3f})\t'
                    'LossTime {loss_time.val:.3f} ({loss_time.avg:.3f})\t'
                    'Backward {backward_time.val:.3f} ({backward_time.avg:.3f})\t'
                    'Loss {loss_total_am.val:.4f} ({loss_total_am.avg:.4f})\t'
                    'Loss_loc {loss_loc_am.val:.4f} ({loss_loc_am.avg:.4f})\t'
                    'Loss_cls {loss_cls_am.val:.4f} ({loss_cls_am.avg:.4f})\t'.
                    format(self.epoch,
                           batch_idx,
                           len(self.train_loader),
                           batch_time=batch_time,
                           data_time=data_time,
                           forward_time=forward_time,
                           loss_time=loss_time,
                           backward_time=backward_time,
                           loss_total_am=loss_total_am,
                           loss_loc_am=loss_loc_am,
                           loss_cls_am=loss_cls_am))

            if self.train_iter % self.print_freq == 0:
                self.writer.add_scalar('train/loss', loss_total_am.avg,
                                       self.train_iter)
                self.writer.add_scalar('train/loss_loc', loss_loc_am.avg,
                                       self.train_iter)
                self.writer.add_scalar('train/loss_cls', loss_cls_am.avg,
                                       self.train_iter)
                self.writer.add_scalar('train/lr', self.learning_rate,
                                       self.train_iter)

                num_prints = self.train_iter // self.print_freq
                # print('num_prints=', num_prints)
                num_prints_rare = num_prints // 100
                # print('num_prints_rare=', num_prints_rare)
                if num_prints_rare == 0 and num_prints % 10 == 0 or num_prints % 100 == 0:
                    print('save histograms')
                    if self.train_iter > 0:
                        import itertools
                        named_parameters = itertools.chain(
                            self.model.multibox_layers.named_parameters(),
                            self.model.extra_layers.named_parameters(),
                        )
                        for name, param in named_parameters:
                            self.writer.add_histogram(
                                name,
                                param.detach().cpu().numpy(),
                                self.train_iter,
                                bins='fd')
                            self.writer.add_histogram(
                                name + '_grad',
                                param.grad.detach().cpu().numpy(),
                                self.train_iter,
                                bins='fd')

                    first_conv = list(self.model.backbone._modules.items()
                                      )[0][1]._parameters['weight']
                    image_grid = torchvision.utils.make_grid(
                        first_conv.detach().cpu(),
                        normalize=True,
                        scale_each=True)
                    image_grid_grad = torchvision.utils.make_grid(
                        first_conv.grad.detach().cpu(),
                        normalize=True,
                        scale_each=True)
                    self.writer.add_image('layers0_conv', image_grid,
                                          self.train_iter)
                    self.writer.add_image('layers0_conv_grad', image_grid_grad,
                                          self.train_iter)

            self.train_iter += 1
            pass

        self.epoch += 1
Example #13
    viz = Visualize()
    viz.visualizeRawPointCloud(first_anchor, True)
    viz.visualizeSphere(first_anchor, True)


# ## Train model
def adjust_learning_rate_exp(optimizer, epoch_num, lr):
    decay_rate = 0.96
    new_lr = lr * math.pow(decay_rate, epoch_num)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

    return new_lr


val_accs = AverageMeter()
loss_ = 0


def accuracy(dista, distb):
    margin = 0
    pred = (dista - distb - margin).cpu().data
    acc = ((pred < 0).sum()).float() / dista.size(0)
    return acc


Example #14
def main():
    # device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    device = torch.device('cpu')

    parser = argparse.ArgumentParser(description='Training of HiDDeN nets')
    # parser.add_argument('--size', '-s', default=128, type=int, help='The size of the images (images are square so this is height and width).')
    parser.add_argument('--data-dir', '-d', required=True, type=str, help='The directory where the data is stored.')
    parser.add_argument('--runs_root', '-r', default=os.path.join('.', 'experiments'), type=str,
                        help='The root folder where data about experiments are stored.')
    parser.add_argument('--batch-size', '-b', default=1, type=int, help='Validation batch size.')

    args = parser.parse_args()
    print_each = 25

    completed_runs = [o for o in os.listdir(args.runs_root)
                      if os.path.isdir(os.path.join(args.runs_root, o)) and o != 'no-noise-defaults']

    print(completed_runs)

    write_csv_header = True
    for run_name in completed_runs:
        current_run = os.path.join(args.runs_root, run_name)
        print(f'Run folder: {current_run}')
        options_file = os.path.join(current_run, 'options-and-config.pickle')
        train_options, hidden_config, noise_config = utils.load_options(options_file)
        train_options.train_folder = os.path.join(args.data_dir, 'val')
        train_options.validation_folder = os.path.join(args.data_dir, 'val')
        train_options.batch_size = args.batch_size
        checkpoint, chpt_file_name = utils.load_last_checkpoint(os.path.join(current_run, 'checkpoints'))
        print(f'Loaded checkpoint from file {chpt_file_name}')

        noiser = Noiser(noise_config)
        model = Hidden(hidden_config, device, noiser, tb_logger=None)
        utils.model_from_checkpoint(model, checkpoint)

        print('Model loaded successfully. Starting validation run...')
        _, val_data = utils.get_data_loaders(hidden_config, train_options)
        file_count = len(val_data.dataset)
        if file_count % train_options.batch_size == 0:
            steps_in_epoch = file_count // train_options.batch_size
        else:
            steps_in_epoch = file_count // train_options.batch_size + 1

        losses_accu = {}
        step = 0
        for image, _ in val_data:
            step += 1
            image = image.to(device)
            message = torch.Tensor(np.random.choice([0, 1], (image.shape[0], hidden_config.message_length))).to(device)
            losses, (encoded_images, noised_images, decoded_messages) = model.validate_on_batch([image, message],
                                                                                                set_eval_mode=True)
            if not losses_accu:  # dict is empty, initialize
                for name in losses:
                    losses_accu[name] = AverageMeter()
            for name, loss in losses.items():
                losses_accu[name].update(loss)
            if step % print_each == 0 or step == steps_in_epoch:
                print(f'Step {step}/{steps_in_epoch}')
                utils.print_progress(losses_accu)
                print('-' * 40)

        # utils.print_progress(losses_accu)
        write_validation_loss(os.path.join(args.runs_root, 'validation_run.csv'), losses_accu, run_name,
                              checkpoint['epoch'],
                              write_header=write_csv_header)
        write_csv_header = False
Example #15
def main():
    axis = 'ax1'
    # CUDA for PyTorch
    device = train_device()

    # Training dataset
    train_params = {'batch_size': 10, 'shuffle': True, 'num_workers': 4}

    data_path = './dataset/dataset_' + axis + '/train/'
    train_dataset = Dataset(data_path,
                            transform=transforms.Compose([Preprocessing()]))
    length = int(len(train_dataset))
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_params)

    # Validation dataset
    data_path = './dataset/dataset_' + axis + '/valid/'
    valid_dataset = Dataset(data_path,
                            transform=transforms.Compose([Preprocessing()]))
    valid_params = {'batch_size': 10, 'shuffle': True, 'num_workers': 4}
    val_loader = torch.utils.data.DataLoader(valid_dataset, **valid_params)

    # Training params
    learning_rate = 1e-4
    max_epochs = 100

    # Use a pretrained model and change the first conv layer from 3 input channels to 1
    model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch',
                           'unet',
                           in_channels=3,
                           out_channels=1,
                           init_features=32,
                           pretrained=True)
    model.encoder1.enc1conv1 = nn.Conv2d(1,
                                         32,
                                         kernel_size=(3, 3),
                                         stride=(1, 1),
                                         padding=(1, 1),
                                         bias=False)
    model.to(device)

    # Optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    dsc_loss = DiceLoss()

    # Metrics
    train_loss = AverageMeter('Training loss', ':.6f')
    val_loss = AverageMeter('Validation loss', ':.6f')
    best_loss = float('inf')
    nb_of_batches = length // train_params['batch_size']

    for epoch in range(max_epochs):
        # re-create the meters so each epoch reports its own averages
        train_loss = AverageMeter('Training loss', ':.6f')
        val_loss = AverageMeter('Validation loss', ':.6f')
        if not epoch:
            logg_file = loggs.Loggs(['epoch', 'train_loss', 'val_loss'])
        model.train()
        for i, (image, label) in enumerate(train_loader):
            torch.cuda.empty_cache()
            image, label = image.to(device), label.to(device)
            optimizer.zero_grad()
            y_pred = model(image)
            loss = dsc_loss(y_pred, label)
            del y_pred
            train_loss.update(loss.item(), image.size(0))
            loss.backward()
            optimizer.step()
            loggs.training_bar(i,
                               nb_of_batches,
                               prefix='Epoch: %d/%d' % (epoch, max_epochs),
                               suffix='Loss: %.6f' % loss.item())
        print(train_loss.avg)

        with torch.no_grad():
            for i, (x_val, y_val) in enumerate(val_loader):
                x_val, y_val = x_val.to(device), y_val.to(device)
                model.eval()
                yhat = model(x_val)
                loss = dsc_loss(yhat, y_val)
                val_loss.update(loss.item(), x_val.size(0))
            print(val_loss)
            logg_file.save([epoch, train_loss.avg, val_loss.avg])

            # Save the best model with minimum validation loss
            if best_loss > val_loss.avg:
                print('Updated model with validation loss %.6f ---> %.6f' %
                      (best_loss, val_loss.avg))
                best_loss = val_loss.avg
                torch.save(model, './model_' + axis + '/best_model.pt')
Example #16
def train(train_loader,
          model,
          criterion,
          optimizer,
          epoch,
          total_epochs=-1,
          performance_stats={},
          verbose=True,
          print_freq=10,
          tensorboard_log_function=None,
          tensorboard_stats=['train_loss']):
    '''
    Trains for one epoch.

    x, y are input and target
    y_hat is the predicted output

    performance_stats is a dictionary of name:function pairs
    where the function calculates some performance score from y and
    y_hat

    see the docs for the 'display_training_stats' function for
    info on verbose, print_freq, tensorboard_log_function, and
    tensorboard_stats
    '''

    base_stats = {'batch_time': AverageMeter(), 'train_loss': AverageMeter()}
    other_stats = {name: AverageMeter() for name in performance_stats.keys()}
    stats = {**base_stats, **other_stats}

    # enter training mode
    model.train()

    # begin timing the epoch
    stopwatch = time.time()

    # iterate over the batches of the epoch
    for i, (x, y) in enumerate(train_loader):
        y = y.cuda(non_blocking=True)
        x = x.cuda()
        # wrap as Variables
        x_var = torch.autograd.Variable(x)
        y_var = torch.autograd.Variable(y)

        # forward pass
        y_hat = model(x_var)
        loss = criterion(y_hat, y_var)

        # track loss and performance stats
        stats['train_loss'].update(loss.item(), x.size(0))
        for stat_name, stat_func in performance_stats.items():
            stats[stat_name].update(stat_func(y_hat.data, y), x.size(0))

        # compute gradient and do backwards pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # track batch time
        stats['batch_time'].update(time.time() - stopwatch)
        stopwatch = time.time()

        # display progress
        if verbose:
            print_stats('train',
                        stats,
                        i,
                        len(train_loader),
                        epoch,
                        total_epochs,
                        print_freq=print_freq)

    # print results
    if verbose:
        print_stats('train',
                    stats,
                    len(train_loader),
                    len(train_loader),
                    epoch,
                    total_epochs,
                    print_freq=1)

    if tensorboard_log_function is not None:
        stats_to_log = {
            k: v
            for k, v in stats.items() if k in tensorboard_stats
        }
        log_stats_to_tensorboard(stats_to_log, tensorboard_log_function, epoch)
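As the docstring above describes, performance_stats maps metric names to functions of the raw predictions and targets. A hypothetical call (the top-1 accuracy lambda and the surrounding objects are illustrative, not part of the original code):

stats = {
    # fraction of samples whose argmax prediction matches the target
    'top1_acc': lambda y_hat, y: (y_hat.argmax(dim=1) == y).float().mean().item(),
}
train(train_loader, model, criterion, optimizer,
      epoch=0, total_epochs=10,
      performance_stats=stats, print_freq=20)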
Example #17
def main():
    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    torch.backends.cudnn.benchmark = True

    train_params = {'batch_size': 50, 'shuffle': True, 'num_workers': 4}
    valid_params = {'batch_size': 100, 'shuffle': True, 'num_workers': 4}

    # Load dataset
    data_path = '../generated_data/'
    my_dataset = Dataset(data_path,
                         transform=transforms.Compose([Preprocessing()]))

    lengths = [int(len(my_dataset) * 0.8), int(len(my_dataset) * 0.2)]
    train_dataset, val_dataset = random_split(my_dataset, lengths)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_params)
    val_loader = torch.utils.data.DataLoader(val_dataset, **valid_params)

    # Training params
    learning_rate = 1e-3
    max_epochs = 4

    # Model
    model = unet.ResUNet(2, 1, n_size=16)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    train_loss = AverageMeter('Training loss', ':.6f')
    val_loss = AverageMeter('Validation loss', ':.6f')
    best_loss = float('inf')

    nb_of_batches = lengths[0] // train_params['batch_size']
    # Training loop
    for epoch in range(max_epochs):
        if not epoch:
            logg_file = loggs.Loggs(['epoch', 'train_loss', 'val_loss'])
        model.train()
        for i, (x_batch, y_labels) in enumerate(train_loader):
            x_batch, y_labels = x_batch.to(device), y_labels.to(device)
            y_pred = model(x_batch)
            #y_pred = torch.round(y_pred[0])
            loss = dice_loss(y_pred, y_labels)
            train_loss.update(loss.item(), x_batch.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loggs.training_bar(i,
                               nb_of_batches,
                               prefix='Epoch: %d/%d' % (epoch, max_epochs),
                               suffix='Loss: %.6f' % loss.item())
        print(train_loss)

        with torch.no_grad():
            for i, (x_val, y_val) in enumerate(val_loader):
                x_val, y_val = x_val.to(device), y_val.to(device)
                model.eval()
                yhat = model(x_val)
                loss = dice_loss(yhat, y_val)
                val_loss.update(loss.item(), x_val.size(0))
                if i == 10: break
            print(val_loss)
            logg_file.save([epoch, train_loss.avg, val_loss.avg])

            # Save the best model with minimum validation loss
            if best_loss > val_loss.avg:
                print('Updated model with validation loss %.6f ---> %.6f' %
                      (best_loss, val_loss.avg))
                best_loss = val_loss.avg
                torch.save(model, 'best_model.pt')
Example #18
    def __init__(self, params: dict):

        self.params = params
        self.device = self.params.device

        self.game = params.game

        if self.game == "health":
            viz_env = "VizdoomHealthGathering-v0"
            self.load_path = "models/health/"
        elif self.game == "defend":
            viz_env = "VizdoomBasic-v0"
            self.load_path = "models/defend/"
        elif self.game == "center":
            viz_env = "VizdoomDefendCenter-v0"
            self.load_path = "models/center/"
        else:
            raise ValueError(f"Unknown game: {self.game}")

        # Initialize the environment
        self.env = gym.make(viz_env)
        self.num_actions = self.env.action_space.n

        # Initialize both deep Q networks
        self.target_net = DQN(60, 80,
                              num_actions=self.num_actions).to(self.device)
        self.pred_net = DQN(60, 80,
                            num_actions=self.num_actions).to(self.device)

        self.optimizer = torch.optim.Adam(self.pred_net.parameters(), lr=2e-5)

        # load a pretrained model
        if self.params.load_model:

            checkpoint = torch.load(self.load_path + "full_model.pk",
                                    map_location=torch.device(self.device))

            self.pred_net.load_state_dict(checkpoint["model_state_dict"])

            self.optimizer.load_state_dict(
                checkpoint["optimizer_state_dict"], )

            self.replay_memory = checkpoint["replay_memory"]
            self.steps = checkpoint["steps"]
            self.learning_steps = checkpoint["learning_steps"]
            self.losses = checkpoint["losses"]
            self.frame_stack = checkpoint["frame_stack"]
            self.params = checkpoint["params"]
            self.params.start_decay = params.start_decay
            self.params.end_decay = params.end_decay
            self.episode = checkpoint["episode"]
            self.epsilon = checkpoint["epsilon"]
            self.stack_size = self.params.stack_size

        # training from scratch
        else:
            # weight init
            self.pred_net.apply(init_weights)

            # init replay memory
            self.replay_memory = ReplayMemory(10000)

            # init frame stack
            self.stack_size = self.params.stack_size
            self.frame_stack = deque(maxlen=self.stack_size)

            # track steps for target network update control
            self.steps = 0
            self.learning_steps = 0

            # loss logs
            self.losses = AverageMeter()

            self.episode = 0

            # epsilon decay parameters
            self.epsilon = self.params.eps_start

        # set target network to prediction network
        self.target_net.load_state_dict(self.pred_net.state_dict())
        self.target_net.eval()

        # move models to GPU
        if self.device == "cuda:0":
            self.target_net = self.target_net.to(self.device)
            self.pred_net = self.pred_net.to(self.device)

        # epsilon decay
        self.epsilon_start = self.params.eps_start

        # tensorboard
        self.writer = SummaryWriter()
Example #19
def test(net, criterion, writer):
    n_iter = 0
    net.eval()
    with torch.no_grad():
        n_test_data = 3000
        n_test_cache = n_test_data
        ds_test = DataSource(dataset_path, n_test_cache, -1)
        idx = np.array(test_indices['idx'].tolist())
        ds_test.load(n_test_data, idx)
        n_test_data = len(ds_test.anchors)
        test_set = TrainingSet(restore, bandwidth)
        test_set.generateAll(ds_test)
        n_test_set = len(test_set)
        if n_test_set == 0:
            print("Empty test set. Aborting test.")
            return
        print("Total size of the test set: ", n_test_set)
        test_size = n_test_set
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=10,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  pin_memory=True,
                                                  drop_last=False)
        anchor_poses = ds_test.anchor_poses
        positive_poses = ds_test.positive_poses
        assert len(anchor_poses) == len(positive_poses)

        test_accs = AverageMeter()
        test_pos_dist = AverageMeter()
        test_neg_dist = AverageMeter()
        anchor_embeddings = np.empty(1)
        positive_embeddings = np.empty(1)
        for batch_idx, (data1, data2, data3) in enumerate(test_loader):
            embedded_a, embedded_p, embedded_n = net(data1.cuda().float(),
                                                     data2.cuda().float(),
                                                     data3.cuda().float())
            dist_to_pos, dist_to_neg, loss, loss_total = criterion(
                embedded_a, embedded_p, embedded_n)
            writer.add_scalar('Test/Loss', loss, n_iter)

            acc = accuracy(dist_to_pos, dist_to_neg)
            test_accs.update(acc, data1.size(0))
            test_pos_dist.update(dist_to_pos.cpu().data.numpy().sum())
            test_neg_dist.update(dist_to_neg.cpu().data.numpy().sum())

            writer.add_scalar('Test/Accuracy', test_accs.avg, n_iter)
            writer.add_scalar('Test/Distance/Positive', test_pos_dist.avg,
                              n_iter)
            writer.add_scalar('Test/Distance/Negative', test_neg_dist.avg,
                              n_iter)

            anchor_embeddings = np.append(
                anchor_embeddings,
                embedded_a.cpu().data.numpy().reshape([1, -1]))
            positive_embeddings = np.append(
                positive_embeddings,
                embedded_p.cpu().data.numpy().reshape([1, -1]))
            n_iter = n_iter + 1

        desc_anchors = anchor_embeddings[1:].reshape(
            [test_size, descriptor_size])
        desc_positives = positive_embeddings[1:].reshape(
            [test_size, descriptor_size])

        sys.setrecursionlimit(50000)
        tree = spatial.KDTree(desc_positives)
        p_norm = 2
        max_pos_dist = 0.05
        max_loc_dist = 5.0
        max_anchor_dist = 1
        for n_nearest_neighbors in range(1, 21):
            loc_count = 0
            for idx in range(test_size):
                nn_dists, nn_indices = tree.query(desc_anchors[idx, :],
                                                  p=p_norm,
                                                  k=n_nearest_neighbors)
                nn_indices = [nn_indices
                              ] if n_nearest_neighbors == 1 else nn_indices

                for nn_i in nn_indices:
                    dist = spatial.distance.euclidean(
                        positive_poses[nn_i, 5:8], anchor_poses[idx, 5:8])
                    if (dist <= max_pos_dist):
                        loc_count = loc_count + 1
                        break

            loc_precision = (loc_count * 1.0) / test_size
            writer.add_scalar('Test/Precision/Localization', loc_precision,
                              n_nearest_neighbors)
Example #20
def main():
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    parser = argparse.ArgumentParser(description='Training of HiDDeN nets')
    parser.add_argument('--hostname',
                        default=socket.gethostname(),
                        help='the  host name of the running server')
    # parser.add_argument('--size', '-s', default=128, type=int, help='The size of the images (images are square so this is height and width).')
    parser.add_argument('--data-dir',
                        '-d',
                        required=True,
                        type=str,
                        help='The directory where the data is stored.')
    parser.add_argument(
        '--runs_root',
        '-r',
        default=os.path.join('.', 'experiments'),
        type=str,
        help='The root folder where data about experiments are stored.')
    parser.add_argument('--batch-size',
                        '-b',
                        default=1,
                        type=int,
                        help='Validation batch size.')

    args = parser.parse_args()

    if args.hostname == 'ee898-System-Product-Name':
        args.data_dir = '/home/ee898/Desktop/chaoning/ImageNet'
        args.hostname = 'ee898'
    elif args.hostname == 'DL178':
        args.data_dir = '/media/user/SSD1TB-2/ImageNet'
    else:
        args.data_dir = '/workspace/data_local/imagenet_pytorch'
    assert args.data_dir

    print_each = 25

    completed_runs = [
        o for o in os.listdir(args.runs_root)
        if os.path.isdir(os.path.join(args.runs_root, o))
        and o != 'no-noise-defaults'
    ]

    print(completed_runs)

    write_csv_header = True
    current_run = args.runs_root
    print(f'Run folder: {current_run}')
    options_file = os.path.join(current_run, 'options-and-config.pickle')
    train_options, hidden_config, noise_config = utils.load_options(
        options_file)
    train_options.train_folder = os.path.join(args.data_dir, 'val')
    train_options.validation_folder = os.path.join(args.data_dir, 'val')
    train_options.batch_size = args.batch_size
    checkpoint, chpt_file_name = utils.load_last_checkpoint(
        os.path.join(current_run, 'checkpoints'))
    print(f'Loaded checkpoint from file {chpt_file_name}')

    noiser = Noiser(noise_config, device, 'jpeg')
    model = Hidden(hidden_config, device, noiser, tb_logger=None)
    utils.model_from_checkpoint(model, checkpoint)

    print('Model loaded successfully. Starting validation run...')
    _, val_data = utils.get_data_loaders(hidden_config, train_options)
    file_count = len(val_data.dataset)
    if file_count % train_options.batch_size == 0:
        steps_in_epoch = file_count // train_options.batch_size
    else:
        steps_in_epoch = file_count // train_options.batch_size + 1

    with torch.no_grad():
        noises = ['webp_10', 'webp_25', 'webp_50', 'webp_75', 'webp_90']
        for noise in noises:
            losses_accu = {}
            step = 0
            for image, _ in val_data:
                step += 1
                image = image.to(device)
                message = torch.Tensor(
                    np.random.choice(
                        [0, 1], (image.shape[0],
                                 hidden_config.message_length))).to(device)
                losses, (
                    encoded_images, noised_images,
                    decoded_messages) = model.validate_on_batch_specific_noise(
                        [image, message], noise=noise)
                if not losses_accu:  # dict is empty, initialize
                    for name in losses:
                        losses_accu[name] = AverageMeter()
                for name, loss in losses.items():
                    losses_accu[name].update(loss)
                if step % print_each == 0 or step == steps_in_epoch:
                    print(f'Step {step}/{steps_in_epoch}')
                    utils.print_progress(losses_accu)
                    print('-' * 40)

            # utils.print_progress(losses_accu)
            write_validation_loss(os.path.join(args.runs_root,
                                               'validation_run.csv'),
                                  losses_accu,
                                  noise,
                                  checkpoint['epoch'],
                                  write_header=write_csv_header)
            write_csv_header = False
Example #21
def validate(val_loader, criterion, model, writer, args, epoch, best_acc):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (data, target) in enumerate(val_loader):
            if args.gpu is not None:  # TODO None?
                data = data.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(data)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, _ = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), data.size(0))
            top1.update(acc1[0], data.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top1=top1))
            end = time.time()

    print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1))
    writer.add_scalar('Test/Acc', top1.avg, epoch)
    writer.add_scalar('Test/Loss', losses.avg, epoch)

    if top1.avg.item() > best_acc:
        print('new best_acc is {top1.avg:.3f}'.format(top1=top1))
        print('saving model {}'.format(args.save_model))
        torch.save(model.state_dict(), args.save_model)
    return top1.avg.item()
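Examples #21 and #23 call an accuracy(output, target, topk=(1, 5)) helper that is not shown on this page. A common implementation of that pattern is sketched below; it matches how the return value is consumed here (a list with one tensor per k, e.g. acc1, _ = accuracy(...)), but it is not necessarily the authors' exact code:

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k (sketch)."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # indices of the top-k class predictions per sample, shape (maxk, batch)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res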
Example #22
    def validate(self, do_dump_images=False, save_checkpoint=False):
        """
        Run validation on the current network state.
        """

        print("-------------- Validation ------------------")

        batch_time = AverageMeter()
        data_time = AverageMeter()
        loss_total_am = AverageMeter()
        loss_loc_am = AverageMeter()
        loss_cls_am = AverageMeter()

        # switch to evaluate mode
        self.model.eval()

        detection_val_dump_dir = os.path.join(self.run_dir,
                                              'detection_val_dump')
        if do_dump_images:
            clean_dir(detection_val_dump_dir)

        iou_threshold_perclass = [
            0.7 if i == 0 else 0.5 for i in range(len(self.model.labelmap))
        ]  # Kitti

        ap_estimator = average_precision.AveragePrecision(
            self.model.labelmap, iou_threshold_perclass)

        end = time.time()
        for batch_idx, sample in enumerate(self.val_loader):
            # Measure data loading time
            data_time.update(time.time() - end)

            input, target, names, pil_images, annotations, stats = sample

            with torch.no_grad():
                input_var, target_var = self.wrap_sample_with_variable(
                    input, target, volatile=True)

                # Compute output tensor of the network
                encoded_tensor = self.model_dp(input_var)
                # Compute loss for logging only
                _, loss_details = self.model.get_loss(encoded_tensor,
                                                      target_var)

            # Save annotation and detection results for further AP calculation
            class_grouped_anno = self.to_class_grouped_anno(annotations)
            detections_all = self.model.get_detections(encoded_tensor, 0.0)
            ap_estimator.add_batch(class_grouped_anno, detections_all)

            # Record loss
            loss_total_am.update(loss_details["loss"], input.size(0))
            loss_loc_am.update(loss_details["loc_loss"], input.size(0))
            loss_cls_am.update(loss_details["cls_loss"], input.size(0))

            # Dump validation images with overlays for developer to subjectively estimate accuracy
            if do_dump_images:
                overlay_conf_threshold = 0.3
                detections_thr = self.model.get_detections(
                    encoded_tensor, overlay_conf_threshold)
                dump_images(names, pil_images, annotations, detections_thr,
                            stats, self.model.labelmap, detection_val_dump_dir)

            # Measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % self.print_freq == 0:
                print(
                    'Validation: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss_total_am.val:.4f} ({loss_total_am.avg:.4f})\t'
                    'Loss_loc {loss_loc_am.val:.4f} ({loss_loc_am.avg:.4f})\t'
                    'Loss_cls {loss_cls_am.val:.4f} ({loss_cls_am.avg:.4f})\t'.
                    format(batch_idx,
                           len(self.val_loader),
                           batch_time=batch_time,
                           data_time=data_time,
                           loss_total_am=loss_total_am,
                           loss_loc_am=loss_loc_am,
                           loss_cls_am=loss_cls_am))

        # After iterating over the whole validation set, calculate individual average precision values and total mAP
        mAP, AP_list = ap_estimator.calculate_mAP()

        for ap, label in zip(AP_list, self.model.labelmap):
            print('{} {:.3f}'.format(label.ljust(20), ap))
        print('   mAP - {mAP:.3f}'.format(mAP=mAP))
        performance_metric = AP_list[self.model.labelmap.index('Car')]

        # Log to tensorboard
        if self.writer is not None:
            self.writer.add_scalar('val/mAP', mAP, self.train_iter)
            self.writer.add_scalar('val/performance_metric',
                                   performance_metric, self.train_iter)
            self.writer.add_scalar('val/loss', loss_total_am.avg,
                                   self.train_iter)
            self.writer.add_scalar('val/loss_loc', loss_loc_am.avg,
                                   self.train_iter)
            self.writer.add_scalar('val/loss_cls', loss_cls_am.avg,
                                   self.train_iter)

        if save_checkpoint:
            # Remember best accuracy and save checkpoint
            is_best = performance_metric > self.best_performance_metric

            if is_best:
                self.best_performance_metric = performance_metric
                torch.save({'state_dict': self.model.state_dict()},
                           self.snapshot_path)

        pass
Example #23
def train(train_loader, criterion, optimizer, epoch, model, writer, mask, args,
          conv_weights):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (data, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:  # TODO None?
            data = data.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        output = model(data)

        loss = criterion(output, target)

        acc1, _ = accuracy(output, target, topk=(1, 5))

        losses.update(loss.item(), data.size(0))
        top1.update(acc1[0], data.size(0))

        optimizer.zero_grad()

        loss.backward()

        S1, S2 = args.S1, args.S2
        if args.repr and any(s1 <= epoch < s1 + S2
                             for s1 in range(S1, args.epochs, S1 + S2)):
            if i == 0:
                print('freeze for this epoch')
            with torch.no_grad():
                for name, W in conv_weights:
                    W.grad[mask[name]] = 0

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'LR {lr:.3f}\t'.format(epoch,
                                         i,
                                         len(train_loader),
                                         batch_time=batch_time,
                                         data_time=data_time,
                                         loss=losses,
                                         top1=top1,
                                         lr=optimizer.param_groups[0]['lr']))

        end = time.time()
    writer.add_scalar('Train/Acc', top1.avg, epoch)
    writer.add_scalar('Train/Loss', losses.avg, epoch)