Example #1
0
def main():
    """Evaluate a saved ProtoWRN checkpoint on a CIFAR test split.

    Reads its configuration from the module-level ``args`` (``dataset``,
    ``test_bs``, ``prefetch``, ``layers``, ``protodim``, ``widen_factor``,
    ``droprate``, ``load``). Builds a WideResNet backbone wrapped in
    ``ProtoWRN``, moves it to the GPU, loads the weights from ``args.load``
    and hands it, the module-level ``state`` dict and the test loader to
    ``test``.
    """
    data_root = '/data/sauravkadavath/cifar10-dataset'
    test_transform = trn.Compose([
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Any dataset name other than 'cifar10' falls back to CIFAR-100.
    use_cifar10 = args.dataset == 'cifar10'
    print("Using CIFAR 10" if use_cifar10 else "Using CIFAR100")
    dataset_cls = dset.CIFAR10 if use_cifar10 else dset.CIFAR100
    num_classes = 10 if use_cifar10 else 100
    test_data = dataset_cls(data_root, train=False, transform=test_transform)

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.test_bs,
        shuffle=False,
        num_workers=args.prefetch,
        pin_memory=True,
    )

    net = ProtoWRN(
        WideResNet(args.layers, args.protodim, args.widen_factor,
                   dropRate=args.droprate),
        num_classes,
        args.protodim,
    )
    net.cuda()
    net.load_state_dict(torch.load(args.load))

    test(net, state, test_loader)
Example #2
0
def setup_model(name, num_classes):
    """Build a classification model by name.

    Args:
        name: ``"WRN-28-2"`` selects the project's ``WideResNet`` (built with
            only ``num_classes``); any other name is looked up in the
            ``tv_models`` namespace and called as a constructor.
        num_classes: number of output classes passed to the constructor.

    Returns:
        The newly constructed model.

    Raises:
        RuntimeError: if ``name`` is not a public, callable entry of
            ``tv_models``. This covers misspellings, private names, and
            submodules such as ``"resnet"`` that cannot build a model.
    """
    if name == "WRN-28-2":
        return WideResNet(num_classes=num_classes)

    # The namespace dict also holds submodules and private helpers; only
    # public callables are plausible model constructors. Rejecting the rest
    # here gives a clear error instead of "'module' object is not callable".
    fn = tv_models.__dict__.get(name)
    if fn is None or name.startswith("_") or not callable(fn):
        raise RuntimeError("Unknown model name {}".format(name))
    return fn(num_classes=num_classes)
def main():
    """Score an in-distribution and an out-of-distribution CIFAR test set and print the AUROC.

    Supported ``(args.in_dataset, args.out_dataset)`` pairs are
    ``('cifar10', 'cifar100')`` and ``('cifar100', 'cifar10')``; any other
    pair raises ``NotImplementedError``. A ProtoWRN loaded from ``args.load``
    scores both sets through ``test``. The AUROC treats in-distribution
    samples as the positive class (label 1).

    NOTE(review): ``test`` is assumed to return a Python list of per-sample
    scores, since the two results are concatenated with ``+``. Confirm this;
    a numpy array would be added elementwise instead.
    """
    test_transform = trn.Compose([
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    root = '/data/sauravkadavath/cifar10-dataset'

    # (in, out) -> (banner, in-dataset class, out-dataset class, #classes)
    pairings = {
        ('cifar10', 'cifar100'):
            ("Using CIFAR 10 as in dataset", dset.CIFAR10, dset.CIFAR100, 10),
        ('cifar100', 'cifar10'):
            ("Using CIFAR100 as in dataset", dset.CIFAR100, dset.CIFAR10, 100),
    }
    pairing = pairings.get((args.in_dataset, args.out_dataset))
    if pairing is None:
        raise NotImplementedError

    banner, in_cls, out_cls, num_classes = pairing
    print(banner)
    in_data = in_cls(root, train=False, transform=test_transform)
    out_data = out_cls(root, train=False, transform=test_transform)

    def make_loader(dataset):
        # Deterministic, non-shuffled evaluation loader.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=args.test_bs,
            shuffle=False,
            num_workers=args.prefetch,
            pin_memory=True,
        )

    in_loader = make_loader(in_data)
    out_loader = make_loader(out_data)

    backbone = WideResNet(args.layers, args.protodim, args.widen_factor,
                          dropRate=args.droprate)
    net = ProtoWRN(backbone, num_classes, args.protodim)
    net.cuda()
    net.load_state_dict(torch.load(args.load))

    in_results = test(net, in_loader)
    out_results = test(net, out_loader)

    # In-distribution samples are labelled 1, out-of-distribution samples 0.
    labels = [1] * len(in_results) + [0] * len(out_results)
    AUROC = sk.roc_auc_score(labels, in_results + out_results)

    print("AUROC = ", AUROC)
Example #4
0
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.prefetch,
                                           pin_memory=True)
# Evaluation batches are served in a fixed order (shuffle=False).
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args.test_bs,
    shuffle=False,
    num_workers=args.prefetch,
    pin_memory=True,
)

# Create model: 'allconv' selects AllConvNet, any other name a WideResNet.
net = (AllConvNet(num_classes) if args.model == 'allconv' else
       WideResNet(args.layers, num_classes, args.widen_factor,
                  dropRate=args.droprate))

start_epoch = 0

# Restore the newest '<dataset>_<model>_baseline_epoch_<i>.pt' found in
# args.load, scanning i = 999 down to 0. Training resumes at the next epoch.
if args.load != '':
    for i in reversed(range(1000)):
        model_name = os.path.join(
            args.load,
            '{}_{}_baseline_epoch_{}.pt'.format(args.dataset, args.model, i))
        if not os.path.isfile(model_name):
            continue
        net.load_state_dict(torch.load(model_name))
        print('Model restored! Epoch:', i)
        start_epoch = i + 1
        break
Example #5
0
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.prefetch,
                                           pin_memory=True)
# Evaluation batches are served in a fixed order (shuffle=False).
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args.test_bs,
    shuffle=False,
    num_workers=args.prefetch,
    pin_memory=True,
)

# Create model with a 1000-way output layer; the checkpoints restored below
# use the 'imagenet_' prefix.
net = (AllConvNet(1000) if args.model == 'allconv' else
       WideResNet(args.layers, 1000, args.widen_factor,
                  dropRate=args.droprate))

start_epoch = 0

# Restore the newest 'imagenet_<model>_baseline_epoch_<i>.pt' found in
# args.load, scanning i = 999 down to 0. Training resumes at the next epoch.
if args.load != '':
    for i in reversed(range(1000)):
        model_name = os.path.join(
            args.load, 'imagenet_{}_baseline_epoch_{}.pt'.format(args.model, i))
        if not os.path.isfile(model_name):
            continue
        net.load_state_dict(torch.load(model_name))
        print('Model restored! Epoch:', i)
        start_epoch = i + 1
        break
Example #6
0
    # train_data.train_data = np.copy(train_data.train_data[train_indices])
    # #


train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.prefetch,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=args.test_bs,
    shuffle=False,
    num_workers=args.prefetch,
    pin_memory=True,
)

# Create the pretrained-architecture model (output size
# args.num_classes_pretrained_net) and wrap it in DataParallel before any
# checkpoint is loaded.
net = nn.DataParallel(
    AllConvNet(args.num_classes_pretrained_net) if args.model == 'allconv'
    else WideResNet(args.layers, args.num_classes_pretrained_net,
                    args.widen_factor, dropRate=args.droprate))

# model_found reports whether a checkpoint was restored.
# NOTE(review): start_epoch is only assigned here when a checkpoint is found.
# Presumably it is initialised elsewhere in the script; confirm.
model_found = False
if args.load != '':
    for i in reversed(range(1000)):
        model_name = os.path.join(
            args.load, 'imagenet_{}_baseline_epoch_{}.pt'.format(args.model, i))
        if not os.path.isfile(model_name):
            continue
        net.load_state_dict(torch.load(model_name))
        print('Model restored! Epoch:', i)
        start_epoch = i + 1
        model_found = True
        break
Example #7
0
# Memory is pinned only when CUDA is available.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.prefetch,
    pin_memory=torch.cuda.is_available())
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=opt.test_bs,
    shuffle=False,
    num_workers=opt.prefetch,
    pin_memory=torch.cuda.is_available())

# Create model. Only the WideResNet ('wrn') is supported. Raise explicitly
# rather than `assert False`, which `python -O` strips and would leave `net`
# undefined for the code below.
if opt.model != 'wrn':
    raise NotImplementedError(opt.model + ' is not supported.')
net = WideResNet(opt.layers,
                 num_classes,
                 opt.widen_factor,
                 dropRate=opt.droprate)

start_epoch = opt.start_epoch

# With GPUs requested: replicate over devices 0..opt.ngpu-1, move to the
# GPU and seed the CUDA RNG.
if opt.ngpu > 0:
    net = torch.nn.DataParallel(net, device_ids=list(range(opt.ngpu)))
    net.cuda()
    torch.cuda.manual_seed(opt.random_seed)

# Restore model if desired
if opt.load != '':
    if opt.test and os.path.isfile(opt.load):
        net.load_state_dict(torch.load(opt.load))
    def __init__(self,
                 root='~/home-nfs/dan/cifar_data',
                 train=True,
                 gold=True,
                 gold_fraction=0.1,
                 corruption_prob=0,
                 corruption_type='unif',
                 transform=None,
                 target_transform=None,
                 download=False,
                 shuffle_indices=None,
                 distinguish_gold=True,
                 seed=1):
        """Load a CIFAR split, optionally as a trusted (gold) or corrupted (silver) subset.

        With ``train=True`` the pickled training batches in
        ``self.train_list`` are concatenated into 50,000 HWC images.
        ``num_gold = int(gold_fraction * 50000)`` is the size of the trusted
        prefix of the permutation ``indices``. That permutation is either
        ``shuffle_indices`` or one drawn here.

        * ``gold=True`` keeps the first ``num_gold`` examples. If
          ``distinguish_gold`` is set, their labels are shifted by
          ``num_classes`` so they can be told apart. The permutation is
          stored in ``self.shuffle_indices``.
        * ``gold=False`` keeps the remaining examples and corrupts their
          labels according to ``corruption_type``:

          - ``'unif'``: transition matrix from ``uniform_mix_C``;
          - ``'flip'``: transition matrix from ``flip_labels_C`` (seeded);
          - ``'hierarchical'``: CIFAR-100 only. Mass ``corruption_prob`` is
            spread evenly over the other fine labels seen under the same
            coarse label;
          - ``'clabels'``: each label is resampled from the temperature-5
            softmax of a WideResNet loaded from ``./cifar<K>_labeler``.

          For the matrix-based types the matrix is stored in
          ``self.corruption_matrix``.

        With ``train=False`` the 10,000 test images and labels are loaded.

        Raises:
            RuntimeError: if ``self._check_integrity()`` fails.
            ValueError: for an unknown ``corruption_type`` (only checked when
                ``train=True`` and ``gold=False``).
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.gold = gold
        self.gold_fraction = gold_fraction
        self.corruption_prob = corruption_prob

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Now load the pickled numpy arrays. pickle.load can execute arbitrary
        # code, so only point `root` at trusted dataset files.
        if self.train:
            self.train_data = []
            self.train_labels = []
            self.train_coarse_labels = []
            for fentry in self.train_list:
                path = os.path.join(root, self.base_folder, fentry[0])
                with open(path, 'rb') as fo:
                    if sys.version_info[0] == 2:
                        entry = pickle.load(fo)
                    else:
                        entry = pickle.load(fo, encoding='latin1')
                self.train_data.append(entry['data'])
                # CIFAR-10 batches carry 'labels'; CIFAR-100 batches carry
                # 'fine_labels' plus 'coarse_labels'.
                if 'labels' in entry:
                    self.train_labels += entry['labels']
                    num_classes = 10
                else:
                    self.train_labels += entry['fine_labels']
                    self.train_coarse_labels += entry['coarse_labels']
                    num_classes = 100

            self.train_data = np.concatenate(self.train_data)
            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
            self.train_data = self.train_data.transpose(
                (0, 2, 3, 1))  # convert to HWC

            # Size of the trusted ("gold") prefix of the permutation.
            num_gold = int(gold_fraction * 50000)

            if gold is True:
                if shuffle_indices is None:
                    indices = np.arange(50000)
                    shuffled_train_labels = self.train_labels
                    if self.gold_fraction >= 0.05:
                        # Reshuffle until the gold prefix contains every class.
                        while len(set(shuffled_train_labels[:num_gold])
                                  ) < num_classes:
                            np.random.shuffle(indices)
                            shuffled_train_labels = list(
                                np.array(self.train_labels)[indices])
                    else:
                        # Small gold sets: take the first
                        # int(gold_fraction * 50000 / num_classes) examples of
                        # each class, then append every remaining index.
                        labels_arr = np.asarray(self.train_labels)
                        per_class = int(self.gold_fraction * 50000 /
                                        num_classes)
                        gold_indices = []
                        for c in range(num_classes):
                            gold_indices.extend(
                                indices[labels_arr == c][:per_class])
                        indices = np.array(
                            gold_indices +
                            list(set(range(50000)) - set(gold_indices)))
                else:
                    indices = shuffle_indices

                self.train_data = self.train_data[indices][:num_gold]
                if distinguish_gold:
                    # this ad-hoc move is done so we can identify which examples are
                    # gold/trusted and which are silver/untrusted
                    self.train_labels = list(
                        np.array(self.train_labels)[indices][:num_gold] +
                        num_classes)
                else:
                    self.train_labels = list(
                        np.array(self.train_labels)[indices][:num_gold])
                self.shuffle_indices = indices
            else:
                indices = np.arange(len(
                    self.train_data)) if shuffle_indices is None else shuffle_indices
                self.train_data = self.train_data[indices][num_gold:]
                self.train_labels = list(
                    np.array(self.train_labels)[indices][num_gold:])
                if corruption_type == 'hierarchical':
                    self.train_coarse_labels = list(
                        np.array(self.train_coarse_labels)[indices][num_gold:])

                if corruption_type == 'unif':
                    C = uniform_mix_C(self.corruption_prob, num_classes)
                elif corruption_type == 'flip':
                    C = flip_labels_C(self.corruption_prob,
                                      num_classes,
                                      seed=seed)
                elif corruption_type == 'hierarchical':
                    assert num_classes == 100, 'You must use CIFAR-100 with the hierarchical corruption.'
                    # Fine labels observed under each of the 20 coarse labels.
                    coarse_fine = [set() for _ in range(20)]
                    for fine, coarse in zip(self.train_labels,
                                            self.train_coarse_labels):
                        coarse_fine[coarse].add(fine)
                    coarse_fine = [list(s) for s in coarse_fine]

                    C = np.eye(num_classes) * (1 - corruption_prob)

                    for i in range(20):
                        tmp = np.copy(coarse_fine[i])
                        for j in range(len(tmp)):
                            tmp2 = np.delete(np.copy(tmp), j)
                            C[tmp[j], tmp2] += corruption_prob * 1 / len(tmp2)
                elif corruption_type == 'clabels':
                    net = WideResNet(40, num_classes, 2, dropRate=0.3).cuda()
                    model_name = './cifar{}_labeler'.format(num_classes)
                    net.load_state_dict(torch.load(model_name))
                    net.eval()
                else:
                    # Literal braces must be doubled inside str.format;
                    # otherwise building the message itself raises KeyError.
                    raise ValueError(
                        "Invalid corruption type '{}' given. Must be in "
                        "{{'unif', 'flip', 'hierarchical', 'clabels'}}".format(
                            corruption_type))

                np.random.seed(seed)
                if corruption_type == 'clabels':
                    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
                    std = [x / 255 for x in [63.0, 62.1, 66.7]]

                    test_transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean, std)
                    ])

                    # obtain sampling probabilities
                    sampling_probs = []
                    print('Starting labeling')

                    batch_size = 64
                    # Stepping by batch_size never yields a trailing empty
                    # batch. The old `len // 64 + 1` loop did so whenever the
                    # length was a multiple of 64, and torch.cat([]) raised.
                    for start in range(0, len(self.train_labels), batch_size):
                        batch = self.train_data[start:start + batch_size]
                        images = torch.cat([
                            test_transform(Image.fromarray(img)).unsqueeze(0)
                            for img in batch
                        ], dim=0)

                        data = V(images).cuda()
                        logits = net(data)
                        # Softmax over classes at temperature 5.
                        smax = F.softmax(logits / 5, dim=1)
                        sampling_probs.append(smax.data.cpu().numpy())

                    sampling_probs = np.concatenate(sampling_probs, 0)
                    print('Finished labeling 1')

                    new_labeling_correct = 0
                    argmax_labeling_correct = 0
                    for i in range(len(self.train_labels)):
                        old_label = self.train_labels[i]
                        new_label = np.random.choice(num_classes,
                                                     p=sampling_probs[i])
                        self.train_labels[i] = new_label
                        if old_label == new_label:
                            new_labeling_correct += 1
                        if old_label == np.argmax(sampling_probs[i]):
                            argmax_labeling_correct += 1
                    print('Finished labeling 2')
                    print('New labeling accuracy:',
                          new_labeling_correct / len(self.train_labels))
                    print('Argmax labeling accuracy:',
                          argmax_labeling_correct / len(self.train_labels))
                else:
                    # Resample each label from its row of the transition matrix.
                    for i in range(len(self.train_labels)):
                        self.train_labels[i] = np.random.choice(
                            num_classes, p=C[self.train_labels[i]])
                    self.corruption_matrix = C

        else:
            path = os.path.join(root, self.base_folder, self.test_list[0][0])
            with open(path, 'rb') as fo:
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
            self.test_data = entry['data']
            if 'labels' in entry:
                self.test_labels = entry['labels']
            else:
                self.test_labels = entry['fine_labels']
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose(
                (0, 2, 3, 1))  # convert to HWC
                    help='Pre-fetching threads.')
args = parser.parse_args()

# Snapshot of every parsed option, printed for the run log.
state = {k: v for k, v in args._get_kwargs()}
print(state)

torch.manual_seed(1)
np.random.seed(1)

# Only 'cifar10' selects 10 classes; every other dataset name means CIFAR-100.
num_classes = 10 if args.dataset == 'cifar10' else 100

backbone = WideResNet(args.layers,
                      args.protodim,
                      args.widen_factor,
                      dropRate=args.droprate)
net = ProtoWRN(backbone, num_classes, args.protodim)
# The network is moved to the GPU unconditionally, so args.ngpu only gates
# CUDA seeding below.
# NOTE(review): args.ngpu > 1 does not enable multi-GPU (no DataParallel);
# confirm that is intended.
net.cuda()
net.load_state_dict(torch.load(args.load))

if args.ngpu > 0:
    torch.cuda.manual_seed(1)

cudnn.benchmark = True  # fire on all cylinders

test_transform = trn.Compose(
    [trn.ToTensor(),
Example #10
0
def main():
    """Train a WideResNet on CIFAR-10/100 and evaluate it after every epoch.

    Configuration comes from the module-level ``args``. Optimiser
    hyper-parameters (``learning_rate``, ``momentum``, ``decay``) and
    per-epoch metrics (``train_loss``, ``test_loss``, ``test_accuracy``) are
    exchanged through the module-level ``state`` dict used by ``train`` and
    ``test``.

    After each epoch:

    * the weights are saved to ``args.save``;
    * the previous epoch's checkpoint is deleted;
    * a row is appended to ``training_log.csv`` in ``args.save``;
    * a summary line is printed.

    Raises:
        Exception: if ``args.save`` exists but is not a directory.
    """
    train_transform = trn.Compose([
        trn.RandomHorizontalFlip(),
        trn.RandomCrop(32, padding=4),
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    test_transform = trn.Compose([
        trn.ToTensor(),
        trn.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    data_root = '/data/sauravkadavath/cifar10-dataset'
    if args.dataset == 'cifar10':
        print("Using CIFAR 10")
        train_data_in = dset.CIFAR10(data_root,
                                     train=True,
                                     transform=train_transform)
        test_data = dset.CIFAR10(data_root,
                                 train=False,
                                 transform=test_transform)
        num_classes = 10
    else:
        print("Using CIFAR100")
        train_data_in = dset.CIFAR100(data_root,
                                      train=True,
                                      transform=train_transform)
        test_data = dset.CIFAR100(data_root,
                                  train=False,
                                  transform=test_transform)
        num_classes = 100

    train_loader_in = torch.utils.data.DataLoader(train_data_in,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  num_workers=args.prefetch,
                                                  pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.prefetch,
                                              pin_memory=True)

    net = WideResNet(args.layers,
                     num_classes,
                     args.widen_factor,
                     dropRate=args.droprate)
    net.cuda()

    optimizer = torch.optim.SGD(net.parameters(),
                                state['learning_rate'],
                                momentum=state['momentum'],
                                weight_decay=state['decay'],
                                nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    # The scheduler is stepped per batch, so the cosine spans all
    # epochs * batches-per-epoch steps.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader_in),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    # Make save directory
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    if not os.path.isdir(args.save):
        raise Exception('%s is not a dir' % args.save)

    def checkpoint_path(ep):
        """Return the checkpoint file path for epoch ``ep`` inside ``args.save``."""
        return os.path.join(
            args.save,
            '{0}_{1}_layers_{2}_widenfactor_{3}_transform_epoch_{4}.pt'.format(
                args.dataset, args.model, str(args.layers),
                str(args.widen_factor), str(ep)))

    print('Beginning Training\n')

    # Start a fresh log with a header row; one row is appended per epoch.
    log_path = os.path.join(args.save, "training_log.csv")
    with open(log_path, 'w') as f:
        f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')

    # Main loop
    for epoch in range(0, args.epochs):
        state['epoch'] = epoch

        begin_epoch = time.time()

        train(net, state, train_loader_in, optimizer, lr_scheduler)
        test(net, state, test_loader)

        # Save model
        torch.save(net.state_dict(), checkpoint_path(epoch))

        # Let us not waste space and delete the previous model
        prev_path = checkpoint_path(epoch - 1)
        if os.path.exists(prev_path):
            os.remove(prev_path)

        # Measure once so the CSV row and the printed line agree.
        elapsed = int(time.time() - begin_epoch)
        test_error = 100 - 100. * state['test_accuracy']

        with open(log_path, 'a') as f:
            f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
                (epoch + 1),
                elapsed,
                state['train_loss'],
                state['test_loss'],
                test_error,
            ))

        # Show results
        print(
            'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'
            .format((epoch + 1), elapsed, state['train_loss'],
                    state['test_loss'], test_error))
Example #11
0
        self.unfreeze(True)

    def unfreeze(self, unfreeze):
        """Set ``requires_grad`` on every trunk parameter to ``unfreeze``.

        Passing ``True`` makes the trunk trainable; ``False`` freezes it.
        The classifier head is not touched.
        """
        trunk_params = self.trunk.parameters()
        for param in trunk_params:
            param.requires_grad = unfreeze

    def forward(self, x):
        """Run the trunk, average-pool with an 8x8 window, flatten, classify.

        The pooled features are reshaped to ``(batch, -1)`` before the
        classifier; only the logits are returned.
        """
        features = self.trunk(x)
        pooled = F.avg_pool2d(features, 8)
        flat = pooled.view(pooled.size(0), -1)
        logits = self.classifier(flat)
        return logits


# Init model, criterion, and optimizer
if args.use_pretrained_model:
    net = WideResNet(args.layers, 1000, args.widen_factor, dropRate=0)

    # net = nn.DataParallel(net)

    # Load pretrained model
    net.load_state_dict(
        torch.load('./snapshots/baseline/imagenet_wrn_baseline_epoch_99.pt'))

    # net = net.module

    net = FineTuneModel(net)
    state['learning_rate'] = 0.01
else:
    net = WideResNet(args.layers,
                     num_classes,
                     args.widen_factor,
Example #12
0
if 'shake' in args.model:
    from models.shake_shake import ResNeXt
    net = ResNeXt({
        'input_shape': (1, 3, 32, 32),
        'n_classes': num_classes,
        'base_channels': 96,
        'depth': 26,
        "shake_forward": True,
        "shake_backward": True,
        "shake_image": True
    })
    args.epochs = 500
    print('Overwriting epochs parameter; now the value is', args.epochs)
elif args.model == 'wrn' or 'wide' in args.model:
    from models.wrn import WideResNet
    net = WideResNet(16, num_classes, 4, dropRate=0.3)
    # args.decay = 5e-4
    # print('Overwriting decay parameter; now the value is', args.decay)
elif args.model == 'resnet':
    from models.resnet import ResNet
    net = ResNet({
        'input_shape': (1, 3, 32, 32),
        'n_classes': num_classes,
        'base_channels': 16,
        'block_type': 'basic',
        'depth': 20
    })
elif args.model == 'densenet':
    from models.densenet import DenseNet
    net = DenseNet({
        'input_shape': (1, 3, 32, 32),
Example #13
0
def main(index, args):
    """Per-process TPU training entry point.

    Args:
        index: process index, presumably supplied by the XLA multiprocessing
            launcher. It is not used in this body.
        args: parsed command-line options (dataset, model, layers, ...).

    Relies on the module-level ``state``, ``train_transform``,
    ``test_transform``, ``train``, ``test``, ``parse_test_results`` and
    ``writer``.

    Each process trains on its own ``DistributedSampler`` shard. Per-epoch
    results from all replicas are gathered with ``xm.rendezvous``. Only the
    master ordinal then aggregates and logs them (CSV, stdout, ``writer``)
    and deletes the previous checkpoint.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'wrn'``.
    """
    if xm.is_master_ordinal():
        print(state)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    xla_device = xm.xla_device()

    # Any dataset name other than 'cifar10' falls back to CIFAR-100.
    if args.dataset == 'cifar10':
        dataset_cls, num_classes = dset.CIFAR10, 10
    else:
        dataset_cls, num_classes = dset.CIFAR100, 100
    train_data = dataset_cls('~/cifarpy/',
                             train=True,
                             download=True,
                             transform=train_transform)
    test_data = dataset_cls('~/cifarpy/',
                            train=False,
                            download=True,
                            transform=test_transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               num_workers=args.prefetch,
                                               drop_last=True,
                                               sampler=train_sampler)
    # NOTE(review): drop_last=True on the test loader discards each replica's
    # final partial batch, so test metrics cover slightly fewer than all test
    # images. It is presumably kept to avoid XLA recompiling for a smaller
    # batch shape; confirm this is acceptable.
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_bs,
                                              num_workers=args.prefetch,
                                              drop_last=True,
                                              sampler=test_sampler)

    # Create model
    if args.model != 'wrn':
        raise NotImplementedError()
    net = WideResNet(args.layers,
                     num_classes,
                     args.widen_factor,
                     dropRate=args.droprate).train().to(xla_device)

    start_epoch = 0

    optimizer = torch.optim.SGD(net.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.decay,
                                nesterov=True)

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-6 / args.learning_rate))

    def checkpoint_path(ep):
        """Return the checkpoint file path for epoch ``ep`` inside ``args.save``."""
        return os.path.join(
            args.save,
            args.dataset + args.model + '_baseline_epoch_' + str(ep) + '.pt')

    results_path = os.path.join(
        args.save, args.dataset + args.model + '_baseline_training_results.csv')

    print('Beginning Training')

    # Main loop
    for epoch in range(start_epoch, args.epochs):
        state['epoch'] = epoch

        begin_epoch = time.time()

        # Every process trains on its own shard (see train_sampler).
        train_loss = train(train_loader, net, optimizer, scheduler, xla_device,
                           args)

        # Calculate test loss
        test_results = test(test_loader, net, xla_device, args)

        # Save model. Does sync between all processes
        xm.save(net.state_dict(), checkpoint_path(epoch))

        # Gather every replica's train loss and test results.
        all_train_losses = xm.rendezvous("calc_train_loss",
                                         payload=str(train_loss))
        all_test_results = xm.rendezvous("calc_test_results",
                                         payload=str(test_results))
        all_test_results = parse_test_results(all_test_results)

        if xm.is_master_ordinal():
            all_train_losses = [float(L) for L in all_train_losses]
            state['train_loss'] = sum(all_train_losses) / float(
                len(all_train_losses))

            # Assumes each parsed result is (summed loss, #correct, #examples).
            # TODO confirm against test().
            num_examples = sum(r[2] for r in all_test_results)
            state['test_loss'] = sum(r[0]
                                     for r in all_test_results) / num_examples
            state['test_accuracy'] = sum(
                r[1] for r in all_test_results) / num_examples

            # Let us not waste space and delete the previous model
            prev_path = checkpoint_path(epoch - 1)
            if os.path.exists(prev_path):
                os.remove(prev_path)

            # Measure once so the CSV row and the printed line agree.
            elapsed = time.time() - begin_epoch
            test_error = 100 - 100. * state['test_accuracy']

            with open(results_path, 'a') as f:
                f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
                    (epoch + 1),
                    elapsed,
                    state['train_loss'],
                    state['test_loss'],
                    test_error,
                ))

            print(
                'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} | Test Error {4:.2f}'
                .format((epoch + 1), int(elapsed), state['train_loss'],
                        state['test_loss'], test_error))

            writer.add_scalar("test_loss", state["test_loss"], epoch + 1)
            writer.add_scalar("test_accuracy", state["test_accuracy"],
                              epoch + 1)

        # Wait for master to finish Disk I/O above
        print("Finished with one epoch")
        xm.rendezvous("epoch_finish")