Look inside DataBlock
Let's dig into the DataBlock source code. By doing this, we can define a custom DataBlock class that fits our own problems.
# Assemble a DataBlock from an image block and a category block: items are
# collected with `get_image_files`, labelled with `parent_label`, and
# `Resize(224)` is passed as the item transform.
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    item_tfms=Resize(224),
)
# Run the block's `summary` on `path`.
dblock.summary(path)
ImageBlock
def ImageBlock(cls=PILImage):
    "Build a `TransformBlock` for images of `cls`"
    # Items are created through `cls.create`; `IntToFloatTensor` is registered
    # as the batch transform.
    block = TransformBlock(type_tfms=cls.create, batch_tfms=IntToFloatTensor)
    return block
class TransformBlock():
    "Bundle the default type/item/batch transforms and dataloader settings for one block of the data block API"

    def __init__(self, type_tfms=None, item_tfms=None, batch_tfms=None, dl_type=None, dls_kwargs=None):
        # Every transform argument is wrapped in `L`.
        self.type_tfms = L(type_tfms)
        # `ToTensor` is combined (via `+`) in front of the caller-supplied item transforms.
        self.item_tfms = ToTensor + L(item_tfms)
        self.batch_tfms = L(batch_tfms)
        self.dl_type = dl_type
        # A missing `dls_kwargs` becomes a fresh dict, so instances never share one.
        self.dls_kwargs = dls_kwargs if dls_kwargs is not None else {}
class PILBase(Image.Image, metaclass=BypassNewMeta):
    # Base image type: carries the default arguments used when opening
    # (`_open_args`) and when showing (`_show_args`) an image.
    _bypass_type = Image.Image
    _show_args = {'cmap': 'viridis'}
    _open_args = {'mode': 'RGB'}

    @classmethod
    def create(cls, fn:(Path,str,Tensor,ndarray,bytes), **kwargs)->None:
        "Build an instance of `cls` from `fn` (a path, tensor, array or raw bytes)"
        # The checks below are deliberately sequential (not `elif`): each
        # conversion can turn `fn` into a type handled by a later check.
        if isinstance(fn, TensorImage):
            # NOTE(review): presumably channel-first -> channel-last; confirm.
            fn = fn.permute(1, 2, 0).type(torch.uint8)
        if isinstance(fn, TensorMask):
            fn = fn.type(torch.uint8)
        if isinstance(fn, Tensor):
            fn = fn.numpy()
        if isinstance(fn, ndarray):
            return cls(Image.fromarray(fn))
        if isinstance(fn, bytes):
            fn = io.BytesIO(fn)
        # Caller kwargs are merged with the class-level opening defaults.
        open_kwargs = merge(cls._open_args, kwargs)
        return cls(load_image(fn, **open_kwargs))

    def show(self, ctx=None, **kwargs):
        "Display the image with `show_image`, using `merge(self._show_args, kwargs)`"
        show_kwargs = merge(self._show_args, kwargs)
        return show_image(self, ctx=ctx, **show_kwargs)

    def __repr__(self):
        # Renders as e.g. "<ClassName> mode=<mode> size=<w>x<h>".
        dims = "x".join(str(d) for d in self.size)
        return f'{self.__class__.__name__} mode={self.mode} size={dims}'
class PILImage(PILBase):
    # Adds nothing of its own; all behaviour and defaults come from `PILBase`.
    pass
CategoryBlock
def CategoryBlock(vocab=None, sort=True, add_na=False):
    "Build a `TransformBlock` for single-label categorical targets"
    # All three arguments are forwarded unchanged to `Categorize`.
    categorize = Categorize(vocab=vocab, sort=sort, add_na=add_na)
    return TransformBlock(type_tfms=categorize)
class Categorize(DisplayedTransform):
    "Reversible transform mapping a category string to its id in `vocab`"
    loss_func = CrossEntropyLossFlat()
    order = 1

    def __init__(self, vocab=None, sort=True, add_na=False):
        # A user-supplied vocab is wrapped in a `CategoryMap` before being stored.
        if vocab is not None:
            vocab = CategoryMap(vocab, sort=sort, add_na=add_na)
        # No extra locals are introduced here: `store_attr()` is called with no
        # arguments, so it works from this frame.
        store_attr()

    def setups(self, dsets):
        # Build the vocab from `dsets` only when none was given at init.
        vocab_missing = self.vocab is None
        if vocab_missing and dsets is not None:
            self.vocab = CategoryMap(dsets, sort=self.sort, add_na=self.add_na)
        self.c = len(self.vocab)

    def encodes(self, o):
        try:
            return TensorCategory(self.vocab.o2i[o])
        except KeyError as err:
            # Re-raise with a message that names the offending label.
            raise KeyError(f"Label '{o}' was not included in the training dataset") from err

    def decodes(self, o):
        label = self.vocab[o]
        return Category(label)
class CategoryMap(CollBase):
    "Collection of categories with the reverse mapping in `o2i`"

    def __init__(self, col, sort=True, add_na=False, strict=False):
        """Build the category list from `col` and its reverse lookup `o2i`.

        `col`: a categorical pandas column, any object exposing `unique()`, or a plain collection.
        `sort`: sort the categories (only applied to non-categorical input).
        `add_na`: put `'#na#'` first and make unknown keys map to id 0.
        `strict`: for categorical input, keep only the categories actually present in `col`.
        """
        # `is_categorical_dtype` is deprecated in pandas >= 2.1; checking the
        # dtype directly is the replacement recommended by pandas.
        from pandas.api.types import CategoricalDtype
        if isinstance(getattr(col, 'dtype', None), CategoricalDtype):
            items = L(col.cat.categories, use_list=True)
            # Remove non-used categories while keeping order
            if strict:
                # Computed once instead of once per candidate category.
                used = col.unique()
                items = L(o for o in items if o in used)
        else:
            if not hasattr(col, 'unique'):
                col = L(col, use_list=True)
            # `o==o` is the generalized definition of non-NaN used by Pandas
            items = L(o for o in col.unique() if o == o)
            if sort:
                items = items.sorted()
        self.items = '#na#' + items if add_na else items
        # With `add_na`, the `defaultdict(int)` sends unseen keys to 0 (the
        # '#na#' slot); otherwise a plain dict is used, so unseen keys raise
        # `KeyError`.
        self.o2i = defaultdict(int, self.items.val2idx()) if add_na else dict(self.items.val2idx())

    def map_objs(self, objs):
        "Map `objs` to IDs"
        return L(self.o2i[o] for o in objs)

    def map_ids(self, ids):
        "Map `ids` to objects in vocab"
        return L(self.items[o] for o in ids)

    def __eq__(self, b):
        return all_equal(b, self)
PointBlock
# Ready-made block for point targets: items are created with
# `TensorPoint.create` and `PointScaler` is the item transform.
PointBlock = TransformBlock(
    type_tfms=TensorPoint.create,
    item_tfms=PointScaler,
)
MaskBlock
def MaskBlock(codes=None):
    "Build a `TransformBlock` for segmentation masks, optionally carrying `codes`"
    # `codes` is handed to `AddMaskCodes`, which becomes the item transform.
    add_codes = AddMaskCodes(codes=codes)
    return TransformBlock(
        type_tfms=PILMask.create,
        item_tfms=add_codes,
        batch_tfms=IntToFloatTensor,
    )
class PILMask(PILBase):
    # Overrides the `PILBase` defaults: opened with mode 'L', shown
    # semi-transparent with the 'tab20' colormap.
    _open_args = {'mode': 'L'}
    _show_args = {'alpha': 0.5, 'cmap': 'tab20'}
class AddMaskCodes(Transform):
    "Attach the code metadata to a `TensorMask`"

    def __init__(self, codes=None):
        self.codes = codes
        if codes is None:
            return
        # `vocab` and `c` are only set when codes were supplied.
        n_codes = len(codes)
        self.vocab = codes
        self.c = n_codes

    def decodes(self, o:TensorMask):
        # With no codes the mask is returned untouched.
        if self.codes is not None:
            o._meta = {'codes': self.codes}
        return o
BBoxBlock
# Ready-made block for bounding-box targets: items are created with
# `TensorBBox.create`, `PointScaler` is the item transform, and `bb_pad` is
# passed to the dataloader as `before_batch`.
BBoxBlock = TransformBlock(
    type_tfms=TensorBBox.create,
    item_tfms=PointScaler,
    dls_kwargs={'before_batch': bb_pad},
)
class TensorBBox(TensorPoint):
    "Basic type for a tensor of bounding boxes in an image"

    @classmethod
    def create(cls, x, img_size=None)->None:
        # Convert to a float tensor with four values per row before wrapping.
        boxes = tensor(x).view(-1, 4).float()
        return cls(boxes, img_size=img_size)

    def show(self, ctx=None, **kwargs):
        # Draw one rectangle per row of the (-1, 4) view, then return `ctx`.
        for box in self.view(-1, 4):
            _draw_rect(ctx, box, hw=False, **kwargs)
        return ctx
class PointScaler(Transform):
    "Scale a tensor representing points"
    order = 1

    def __init__(self, do_scale=True, y_first=False):
        self.do_scale, self.y_first = do_scale, y_first

    def _grab_sz(self, x):
        "Remember the size of image `x` (from its last two dims if a tensor, else `x.size`) and return `x` unchanged"
        self.sz = [x.shape[-1], x.shape[-2]] if isinstance(x, Tensor) else x.size
        return x

    def _get_sz(self, x):
        "Return the size to scale `x` with: the last grabbed image size if any, else `x`'s own `img_size`"
        sz = x.get_meta('img_size')
        # `self.sz` only exists once `_grab_sz` has seen an image. Reading it
        # through `getattr` lets a point that carries its own `img_size` be
        # processed first without raising `AttributeError`.
        grabbed = getattr(self, 'sz', None)
        assert sz is not None or grabbed is not None, "Size could not be inferred, pass it in the init of your TensorPoint with `img_size=...`"
        return sz if grabbed is None else grabbed

    def setups(self, dl):
        # Inspect the first item and record the element count of any `TensorPoint` in it.
        its = dl.do_item(0)
        for t in its:
            if isinstance(t, TensorPoint):
                self.c = t.numel()

    # NOTE(review): the repeated `encodes`/`decodes` definitions below differ
    # only by type annotation. They presumably rely on the `Transform` base
    # class (defined outside this file) registering each one by annotated type;
    # confirm before merging or reordering them.
    def encodes(self, x:(PILBase,TensorImageBase)): return self._grab_sz(x)
    def decodes(self, x:(PILBase,TensorImageBase)): return self._grab_sz(x)

    def encodes(self, x:TensorPoint): return _scale_pnts(x, self._get_sz(x), self.do_scale, self.y_first)
    def decodes(self, x:TensorPoint): return _unscale_pnts(x.view(-1, 2), self._get_sz(x))
BBoxLblBlock
def BBoxLblBlock(vocab=None, add_na=True):
    "Build a `TransformBlock` for labeled bounding boxes, optionally with `vocab`"
    # `vocab` and `add_na` go to `MultiCategorize`; `BBoxLabeler` is the item transform.
    multi_cat = MultiCategorize(vocab=vocab, add_na=add_na)
    return TransformBlock(type_tfms=multi_cat, item_tfms=BBoxLabeler)
class BBoxLabeler(Transform):
    # Remembers the most recently decoded boxes and labels so that, once both
    # are available, they can be returned together as a `LabeledBBox`.
    def setups(self, dl):
        self.vocab = dl.vocab

    def decode(self, x, **kwargs):
        # Reset the cached boxes/labels at the start of every decode call.
        self.bbox = None
        self.lbls = None
        return self._call('decodes', x, **kwargs)

    # NOTE(review): the two `decodes` definitions differ only by type
    # annotation. They presumably rely on the `Transform` base class (defined
    # outside this file) registering each one by annotated type; confirm before
    # merging them.
    def decodes(self, x:TensorMultiCategory):
        self.lbls = [self.vocab[a] for a in x]
        if self.bbox is None:
            return x
        return LabeledBBox(self.bbox, self.lbls)

    def decodes(self, x:TensorBBox):
        self.bbox = x
        if self.lbls is None:
            return self.bbox
        return LabeledBBox(self.bbox, self.lbls)
MultiCategoryBlock
def MultiCategoryBlock(encoded=False, vocab=None, add_na=False):
    "Build a `TransformBlock` for multi-label categorical targets"
    if encoded:
        # With `encoded=True` a single `EncodedMultiCategorize` is used; `add_na` is not consulted.
        tfm = EncodedMultiCategorize(vocab=vocab)
    else:
        tfm = [MultiCategorize(vocab=vocab, add_na=add_na), OneHotEncode]
    return TransformBlock(type_tfms=tfm)
TextBlock
class TextBlock(TransformBlock):
    "A `TransformBlock` for texts"

    @delegates(Numericalize.__init__)
    def __init__(self, tok_tfm, vocab=None, is_lm=False, seq_len=72, backwards=False, **kwargs):
        # Type transforms: the tokenizer, then `Numericalize` (which receives
        # the extra kwargs), then optionally `reverse_text`.
        type_tfms = [tok_tfm, Numericalize(vocab, **kwargs)]
        if backwards:
            type_tfms.append(reverse_text)
        # `is_lm` selects the dataloader type and how `seq_len` is passed to it.
        if is_lm:
            dl_type = LMDataLoader
            dls_kwargs = {'seq_len': seq_len}
        else:
            dl_type = SortedDL
            dls_kwargs = {'before_batch': partial(pad_input_chunk, seq_len=seq_len)}
        super().__init__(type_tfms=type_tfms, dl_type=dl_type, dls_kwargs=dls_kwargs)

    @classmethod
    @delegates(Tokenizer.from_df, keep=True)
    def from_df(cls, text_cols, vocab=None, is_lm=False, seq_len=72, backwards=False, min_freq=3, max_vocab=60000, **kwargs):
        "Build a `TextBlock` from a dataframe using `text_cols`"
        # Extra kwargs go to the tokenizer; `min_freq`/`max_vocab` reach `__init__` as kwargs.
        tok = Tokenizer.from_df(text_cols, **kwargs)
        return cls(tok, vocab=vocab, is_lm=is_lm, seq_len=seq_len,
                   backwards=backwards, min_freq=min_freq, max_vocab=max_vocab)

    @classmethod
    @delegates(Tokenizer.from_folder, keep=True)
    def from_folder(cls, path, vocab=None, is_lm=False, seq_len=72, backwards=False, min_freq=3, max_vocab=60000, **kwargs):
        "Build a `TextBlock` from a `path`"
        # Extra kwargs go to the tokenizer; `min_freq`/`max_vocab` reach `__init__` as kwargs.
        tok = Tokenizer.from_folder(path, **kwargs)
        return cls(tok, vocab=vocab, is_lm=is_lm, seq_len=seq_len,
                   backwards=backwards, min_freq=min_freq, max_vocab=max_vocab)